/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block
 */

#pragma once

#include <stdint.h>

#include "../thread/thread_reduce.cuh"
#include "../thread/thread_scan.cuh"
#include "../block/block_scan.cuh"
#include "../util_ptx.cuh"
#include "../util_arch.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"


/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {
/**
 * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
 * \ingroup BlockModule
 *
 * \tparam BLOCK_DIM_X          The thread block length in threads along the X dimension
 * \tparam RADIX_BITS           The number of radix bits per digit place
 * \tparam IS_DESCENDING        Whether or not the sorted-order is high-to-low
 * \tparam MEMOIZE_OUTER_SCAN   <b>[optional]</b> Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise).  See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
 * \tparam INNER_SCAN_ALGORITHM <b>[optional]</b> The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam SMEM_CONFIG          <b>[optional]</b> Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
 * \tparam BLOCK_DIM_Y          <b>[optional]</b> The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z          <b>[optional]</b> The thread block length in threads along the Z dimension (default: 1)
 * \tparam PTX_ARCH             <b>[optional]</b> \ptxversion
 *
 * \par Overview
 * For a given digit place, each key is assigned a rank equal to the number of
 * keys in the tile that precede it in a stable sorted ordering of that digit
 * (i.e., its scatter offset within the tile).
 * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
 * - \blocked
 *
 * \par Performance Considerations
 * - \granularity
 *
 * \par Examples
 * \par
 * - <b>Example 1:</b> Simple radix rank of 32-bit integer keys (a minimal sketch)
 *      \code
 *      #include <cub/cub.cuh>
 *
 *      template <int BLOCK_THREADS>
 *      __global__ void ExampleKernel(...)
 *      {
 *          // Specialize BlockRadixRank for a 1D block of BLOCK_THREADS threads,
 *          // ranking ascending 5-bit digits
 *          typedef cub::BlockRadixRank<BLOCK_THREADS, 5, false> BlockRadixRank;
 *
 *          // Allocate shared memory for BlockRadixRank
 *          __shared__ typename BlockRadixRank::TempStorage temp_storage;
 *
 *          // Obtain one key per thread (elided)
 *          unsigned int keys[1];
 *          int ranks[1];
 *          ...
 *
 *          // Rank keys by the digit in the least-significant 5 bits
 *          BlockRadixRank(temp_storage).RankKeys(keys, ranks, 0, 5);
 *      }
 *      \endcode
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    bool                    MEMOIZE_OUTER_SCAN      = (CUB_PTX_ARCH >= 350) ? true : false,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    cudaSharedMemConfig     SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRank
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    // Integer type for digit counters (to be packed into words of type PackedCounter)
    typedef unsigned short DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks
    typedef typename If<(SMEM_CONFIG == cudaSharedMemBankSizeEightByte),
        unsigned long long,
        unsigned int>::Type PackedCounter;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        BYTES_PER_COUNTER           = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER       = Log2<BYTES_PER_COUNTER>::VALUE,

        PACKING_RATIO               = sizeof(PackedCounter) / sizeof(DigitCounter),
        LOG_PACKING_RATIO           = Log2<PACKING_RATIO>::VALUE,

        LOG_COUNTER_LANES           = CUB_MAX((RADIX_BITS - LOG_PACKING_RATIO), 0),                // Always at least one lane
        COUNTER_LANES               = 1 << LOG_COUNTER_LANES,

        // The number of packed counters per thread (plus one for padding)
        PADDED_COUNTER_LANES        = COUNTER_LANES + 1,
        RAKING_SEGMENT              = PADDED_COUNTER_LANES,
    };
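
    // Worked example (hypothetical parameters): with RADIX_BITS = 5 and 4-byte
    // PackedCounter words (SMEM_CONFIG = cudaSharedMemBankSizeFourByte), each
    // word packs PACKING_RATIO = 2 16-bit DigitCounters, so the 32 digit bins
    // fold into COUNTER_LANES = 16 lanes of 2 sub-counters each, padded to 17
    // lanes to stagger raking accesses across shared memory banks.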

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:


    /// BlockScan type
    typedef BlockScan<
            PackedCounter,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScan;


    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        union Aliasable
        {
            DigitCounter            digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
            PackedCounter           raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

        } aliasable;

        // Storage for scanning local ranks
        typename BlockScan::TempStorage block_scan;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

    /// Copy of raking segment, promoted to registers
    PackedCounter cached_segment[RAKING_SEGMENT];


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * Internal storage allocator
     */
    __device__ __forceinline__ _TempStorage& PrivateStorage()
    {
        __shared__ _TempStorage private_storage;
        return private_storage;
    }


    /**
     * Performs upsweep raking reduction, returning the aggregate
     */
    __device__ __forceinline__ PackedCounter Upsweep()
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
        PackedCounter *raking_ptr;

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data into registers
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                cached_segment[i] = smem_raking_ptr[i];
            }
            raking_ptr = cached_segment;
        }
        else
        {
            raking_ptr = smem_raking_ptr;
        }

        return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
    }


    /// Performs exclusive downsweep raking scan
    __device__ __forceinline__ void ExclusiveDownsweep(
        PackedCounter raking_partial)
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

        PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
            cached_segment :
            smem_raking_ptr;

        // Exclusive raking downsweep scan
        internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data back to smem
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                smem_raking_ptr[i] = cached_segment[i];
            }
        }
    }


    /**
     * Reset shared memory digit counters
     */
    __device__ __forceinline__ void ResetCounters()
    {
        // Reset shared memory digit counters
        #pragma unroll
        for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
        {
            *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
        }
    }


    /**
     * Block-scan prefix callback
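     *
     * Seeds the scanned value so that each packed sub-counter field also
     * accumulates the block-wide totals of all lower-indexed fields: field k
     * of the returned prefix sums the aggregate's fields 0..k-1 (bits shifted
     * past the word are discarded), which ranks digits in higher sub-counter
     * positions after every digit in a lower position.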
     */
    struct PrefixCallBack
    {
        __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate)
        {
            PackedCounter block_prefix = 0;

            // Propagate totals in packed fields
            #pragma unroll
            for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
            {
                block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
            }

            return block_prefix;
        }
    };


    /**
     * Scan shared memory digit counters.
     */
    __device__ __forceinline__ void ScanCounters()
    {
        // Upsweep scan
        PackedCounter raking_partial = Upsweep();

        // Compute exclusive sum
        PackedCounter exclusive_partial;
        PrefixCallBack prefix_call_back;
        BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

        // Downsweep scan with exclusive partial
        ExclusiveDownsweep(exclusive_partial);
    }

public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank()
    :
        temp_storage(PrivateStorage()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
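     *
     * On return, \p ranks holds each key's scatter offset for the current
     * digit place, i.e., the number of tile keys that precede it in a stable
     * sort on that digit.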
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        DigitCounter    thread_prefixes[KEYS_PER_THREAD];   // For each key, the count of previous keys in this tile having the same digit
        DigitCounter*   digit_counters[KEYS_PER_THREAD];    // For each key, a pointer to its corresponding digit counter in smem

        // Reset shared memory digit counters
        ResetCounters();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Get digit
            unsigned int digit = BFE(keys[ITEM], current_bit, num_bits);

            // Get sub-counter
            unsigned int sub_counter = digit >> LOG_COUNTER_LANES;

            // Get counter lane
            unsigned int counter_lane = digit & (COUNTER_LANES - 1);

            if (IS_DESCENDING)
            {
                sub_counter = PACKING_RATIO - 1 - sub_counter;
                counter_lane = COUNTER_LANES - 1 - counter_lane;
            }

            // Pointer to smem digit counter
            digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

            // Load thread-exclusive prefix
            thread_prefixes[ITEM] = *digit_counters[ITEM];

            // Store inclusive prefix
            *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
        }

        CTA_SYNC();

        // Scan shared memory counters
        ScanCounters();

        CTA_SYNC();

        // Extract the local ranks of each key
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Add in thread block exclusive prefix
            ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
        }
    }


    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, the digit counts for the corresponding digits are also provided.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        // Rank keys
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get the inclusive and exclusive digit totals corresponding to the calling thread.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                // Obtain ex/inclusive digit counts.  (Unfortunately these all reside in the
                // first counter column, resulting in unavoidable bank conflicts.)
                unsigned int counter_lane   = (bin_idx & (COUNTER_LANES - 1));
                unsigned int sub_counter    = bin_idx >> (LOG_COUNTER_LANES);

                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
            }
        }
    }
};




/**
 * Radix-rank using match.any
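 *
 * Counts same-digit peers within each warp via a match.any-style ballot
 * (cub::MatchAny), accumulates per-warp digit counters in shared memory,
 * then block-scans those counters to produce tile-wide ranks.
 *
 * \par
 * A minimal usage sketch (assumed parameters: 128 threads, ascending
 * 5-bit digits, one key per thread):
 * \code
 * typedef cub::BlockRadixRankMatch<128, 5, false> BlockRadixRankMatch;
 * __shared__ typename BlockRadixRankMatch::TempStorage temp_storage;
 *
 * unsigned int keys[1];   // one key per thread, loaded elsewhere
 * int ranks[1];
 * BlockRadixRankMatch(temp_storage).RankKeys(keys, ranks, 0, 5);
 * \endcode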
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRankMatch
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    typedef int32_t    RankT;
    typedef int32_t    DigitCounterT;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        PADDED_WARPS            = ((WARPS & 0x1) == 0) ?
                                    WARPS + 1 :
                                    WARPS,

        COUNTERS                = PADDED_WARPS * RADIX_DIGITS,
        RAKING_SEGMENT          = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        PADDED_RAKING_SEGMENT   = ((RAKING_SEGMENT & 0x1) == 0) ?
                                    RAKING_SEGMENT + 1 :
                                    RAKING_SEGMENT,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type
    typedef BlockScan<
            DigitCounterT,
            BLOCK_THREADS,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScanT;


    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        typename BlockScanT::TempStorage            block_scan;

        union __align__(16) Aliasable
        {
            volatile DigitCounterT                  warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
            DigitCounterT                           raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];

        } aliasable;
    };


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;



public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{


    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRankMatch(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits)                           ///< [in] The number of bits in the current digit
    {
        // Initialize shared digit counters

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;

        CTA_SYNC();

        // Each warp will strip-mine its section of input, one strip at a time

        volatile DigitCounterT  *digit_counters[KEYS_PER_THREAD];
        uint32_t                warp_id         = linear_tid >> LOG_WARP_THREADS;
        uint32_t                lane_mask_lt    = LaneMaskLt();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // My digit
            uint32_t digit = BFE(keys[ITEM], current_bit, num_bits);

            if (IS_DESCENDING)
                digit = RADIX_DIGITS - digit - 1;

            // Mask of peers who have same digit as me
            uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);

            // Pointer to smem digit counter for this key
            digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

            // Number of occurrences in previous strips
            DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of peers having same digit as me
            int32_t digit_count = __popc(peer_mask);

            // Number of lower-ranked peers having same digit seen so far
            int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

            if (peer_digit_prefix == 0)
            {
                // First thread for each digit updates the shared warp counter
                *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
            }

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of prior keys having same digit
            ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
        }

        CTA_SYNC();

        // Scan warp counters

        DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];

        BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];

        CTA_SYNC();

        // Seed ranks with counter values from previous warps
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
            ranks[ITEM] += *digit_counters[ITEM];
    }


    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, the digit counts for the corresponding digits are also provided.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        int             current_bit,                        ///< [in] The least-significant bit position of the current digit to extract
        int             num_bits,                           ///< [in] The number of bits in the current digit
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])            ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        RankKeys(keys, ranks, current_bit, num_bits);

        // Get exclusive count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)