/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep.
 */


#pragma once

#include <stdint.h>

#include "../thread/thread_load.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_radix_rank.cuh"
#include "../block/block_exchange.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Radix ranking algorithm
 */
enum RadixRankAlgorithm
{
    RADIX_RANK_BASIC,       ///< Rank keys using BlockRadixRank (without memoization)
    RADIX_RANK_MEMOIZE,     ///< Rank keys using BlockRadixRank with memoization enabled
    RADIX_RANK_MATCH        ///< Rank keys using BlockRadixRankMatch
};

/**
 * Parameterizable tuning policy type for AgentRadixSortDownsweep
 */
template <
    int                         _BLOCK_THREADS,         ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,      ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,        ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,         ///< Cache load modifier for reading keys (and values)
    RadixRankAlgorithm          _RANK_ALGORITHM,        ///< The radix ranking algorithm to use
    BlockScanAlgorithm          _SCAN_ALGORITHM,        ///< The block scan algorithm to use
    int                         _RADIX_BITS>            ///< The number of radix bits, i.e., log2(bins)
struct AgentRadixSortDownsweepPolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,           ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,        ///< Items per thread (per tile of input)
        RADIX_BITS              = _RADIX_BITS,              ///< The number of radix bits, i.e., log2(bins)
    };

    static const BlockLoadAlgorithm  LOAD_ALGORITHM     = _LOAD_ALGORITHM;    ///< The BlockLoad algorithm to use
    static const CacheLoadModifier   LOAD_MODIFIER      = _LOAD_MODIFIER;     ///< Cache load modifier for reading keys (and values)
    static const RadixRankAlgorithm  RANK_ALGORITHM     = _RANK_ALGORITHM;    ///< The radix ranking algorithm to use
    static const BlockScanAlgorithm  SCAN_ALGORITHM     = _SCAN_ALGORITHM;    ///< The BlockScan algorithm to use
};
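
// Illustrative example (not part of the library): one way this tuning policy
// might be instantiated. The specific values below are assumptions for
// demonstration only, not a vetted tuning for any particular architecture:
//
//     typedef AgentRadixSortDownsweepPolicy<
//         128,                        // _BLOCK_THREADS
//         17,                         // _ITEMS_PER_THREAD
//         BLOCK_LOAD_DIRECT,          // _LOAD_ALGORITHM
//         LOAD_LDG,                   // _LOAD_MODIFIER
//         RADIX_RANK_MEMOIZE,         // _RANK_ALGORITHM
//         BLOCK_SCAN_WARP_SCANS,      // _SCAN_ALGORITHM
//         4>                          // _RADIX_BITS (16 bins)
//         DownsweepPolicy;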


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/




/**
 * \brief AgentRadixSortDownsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort downsweep.
 */
template <
    typename AgentRadixSortDownsweepPolicy,     ///< Parameterized AgentRadixSortDownsweepPolicy tuning policy type
    bool     IS_DESCENDING,                     ///< Whether or not the sorted-order is high-to-low
    typename KeyT,                              ///< KeyT type
    typename ValueT,                            ///< ValueT type
    typename OffsetT>                           ///< Signed integer type for global offsets
struct AgentRadixSortDownsweep
{
    //---------------------------------------------------------------------
    // Type definitions and constants
    //---------------------------------------------------------------------

    // Appropriate unsigned-bits representation of KeyT
    typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

    static const UnsignedBits           LOWEST_KEY  = Traits<KeyT>::LOWEST_KEY;
    static const UnsignedBits           MAX_KEY     = Traits<KeyT>::MAX_KEY;

    static const BlockLoadAlgorithm     LOAD_ALGORITHM  = AgentRadixSortDownsweepPolicy::LOAD_ALGORITHM;
    static const CacheLoadModifier      LOAD_MODIFIER   = AgentRadixSortDownsweepPolicy::LOAD_MODIFIER;
    static const RadixRankAlgorithm     RANK_ALGORITHM  = AgentRadixSortDownsweepPolicy::RANK_ALGORITHM;
    static const BlockScanAlgorithm     SCAN_ALGORITHM  = AgentRadixSortDownsweepPolicy::SCAN_ALGORITHM;

    enum
    {
        BLOCK_THREADS           = AgentRadixSortDownsweepPolicy::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentRadixSortDownsweepPolicy::ITEMS_PER_THREAD,
        RADIX_BITS              = AgentRadixSortDownsweepPolicy::RADIX_BITS,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,

        RADIX_DIGITS            = 1 << RADIX_BITS,
        KEYS_ONLY               = Equals<ValueT, NullType>::VALUE,
    };

    // Input iterator wrapper types (for applying cache modifiers)
    typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT>    KeysItr;
    typedef CacheModifiedInputIterator<LOAD_MODIFIER, ValueT, OffsetT>          ValuesItr;

    // Radix ranking type to use
    typedef typename If<(RANK_ALGORITHM == RADIX_RANK_BASIC),
            BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
            typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE),
                BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
                BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>
            >::Type
        >::Type BlockRadixRankT;

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = BlockRadixRankT::BINS_TRACKED_PER_THREAD
    };

    // BlockLoad type (keys)
    typedef BlockLoad<
        UnsignedBits,
        BLOCK_THREADS,
        ITEMS_PER_THREAD,
        LOAD_ALGORITHM> BlockLoadKeysT;

    // BlockLoad type (values)
    typedef BlockLoad<
        ValueT,
        BLOCK_THREADS,
        ITEMS_PER_THREAD,
        LOAD_ALGORITHM> BlockLoadValuesT;

    // Value exchange array type
    typedef ValueT ValueExchangeT[TILE_ITEMS];

    /**
     * Shared memory storage layout
     */
    union __align__(16) _TempStorage
    {
        typename BlockLoadKeysT::TempStorage    load_keys;
        typename BlockLoadValuesT::TempStorage  load_values;
        typename BlockRadixRankT::TempStorage   radix_rank;

        struct
        {
            UnsignedBits                        exchange_keys[TILE_ITEMS];
            OffsetT                             relative_bin_offsets[RADIX_DIGITS];
        };

        Uninitialized<ValueExchangeT>           exchange_values;

        OffsetT                                 exclusive_digit_prefix[RADIX_DIGITS];
    };


    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};
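
    // Usage sketch (illustrative assumption, not library code): a kernel
    // typically declares one instance of this storage in shared memory and
    // passes it to the agent's constructor, e.g.:
    //
    //     __shared__ typename AgentRadixSortDownsweep<
    //         DownsweepPolicy, false, KeyT, ValueT, OffsetT>::TempStorage temp_storage;
    //
    // where DownsweepPolicy is a hypothetical AgentRadixSortDownsweepPolicy
    // instantiation like the one sketched above.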


    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    // Shared storage for this CTA
    _TempStorage    &temp_storage;

    // Input and output device pointers
    KeysItr         d_keys_in;
    ValuesItr       d_values_in;
    UnsignedBits    *d_keys_out;
    ValueT          *d_values_out;

    // The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads)
    OffsetT         bin_offset[BINS_TRACKED_PER_THREAD];

    // The least-significant bit position of the current digit to extract
    int             current_bit;

    // Number of bits in current digit
    int             num_bits;

    // Whether to short-circuit
    int             short_circuit;

    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------


    /**
     * Scatter ranked keys through shared memory, then to device-accessible memory
     */
    template <bool FULL_TILE>
    __device__ __forceinline__ void ScatterKeys(
        UnsignedBits    (&twiddled_keys)[ITEMS_PER_THREAD],
        OffsetT         (&relative_bin_offsets)[ITEMS_PER_THREAD],
        int             (&ranks)[ITEMS_PER_THREAD],
        OffsetT         valid_items)
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            temp_storage.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
        }

        CTA_SYNC();

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            UnsignedBits key            = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
            UnsignedBits digit          = BFE(key, current_bit, num_bits);
            relative_bin_offsets[ITEM]  = temp_storage.relative_bin_offsets[digit];

            // Un-twiddle
            key = Traits<KeyT>::TwiddleOut(key);

            if (FULL_TILE ||
                (static_cast<OffsetT>(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items))
            {
                d_keys_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = key;
            }
        }
    }


    /**
     * Scatter ranked values through shared memory, then to device-accessible memory
     */
    template <bool FULL_TILE>
    __device__ __forceinline__ void ScatterValues(
        ValueT      (&values)[ITEMS_PER_THREAD],
        OffsetT     (&relative_bin_offsets)[ITEMS_PER_THREAD],
        int         (&ranks)[ITEMS_PER_THREAD],
        OffsetT     valid_items)
    {
        CTA_SYNC();

        ValueExchangeT &exchange_values = temp_storage.exchange_values.Alias();

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            exchange_values[ranks[ITEM]] = values[ITEM];
        }

        CTA_SYNC();

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            ValueT value = exchange_values[threadIdx.x + (ITEM * BLOCK_THREADS)];

            if (FULL_TILE ||
                (static_cast<OffsetT>(threadIdx.x + (ITEM * BLOCK_THREADS)) < valid_items))
            {
                d_values_out[relative_bin_offsets[ITEM] + threadIdx.x + (ITEM * BLOCK_THREADS)] = value;
            }
        }
    }

    /**
     * Load a tile of keys (specialized for full tile, any ranking algorithm)
     */
    template <int _RANK_ALGORITHM>
    __device__ __forceinline__ void LoadKeys(
        UnsignedBits                (&keys)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        UnsignedBits                oob_item,
        Int2Type<true>              is_full_tile,
        Int2Type<_RANK_ALGORITHM>   rank_algorithm)
    {
        BlockLoadKeysT(temp_storage.load_keys).Load(
            d_keys_in + block_offset, keys);

        CTA_SYNC();
    }


    /**
     * Load a tile of keys (specialized for partial tile, any ranking algorithm)
     */
    template <int _RANK_ALGORITHM>
    __device__ __forceinline__ void LoadKeys(
        UnsignedBits                (&keys)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        UnsignedBits                oob_item,
        Int2Type<false>             is_full_tile,
        Int2Type<_RANK_ALGORITHM>   rank_algorithm)
    {
        // Register pressure work-around: moving valid_items through shfl prevents compiler
        // from reusing guards/addressing from prior guarded loads
        valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff);

        BlockLoadKeysT(temp_storage.load_keys).Load(
            d_keys_in + block_offset, keys, valid_items, oob_item);

        CTA_SYNC();
    }


    /**
     * Load a tile of keys (specialized for full tile, match ranking algorithm)
     */
    __device__ __forceinline__ void LoadKeys(
        UnsignedBits                (&keys)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        UnsignedBits                oob_item,
        Int2Type<true>              is_full_tile,
        Int2Type<RADIX_RANK_MATCH>  rank_algorithm)
    {
        LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys);
    }


    /**
     * Load a tile of keys (specialized for partial tile, match ranking algorithm)
     */
    __device__ __forceinline__ void LoadKeys(
        UnsignedBits                (&keys)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        UnsignedBits                oob_item,
        Int2Type<false>             is_full_tile,
        Int2Type<RADIX_RANK_MATCH>  rank_algorithm)
    {
        // Register pressure work-around: moving valid_items through shfl prevents compiler
        // from reusing guards/addressing from prior guarded loads
        valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff);

        LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item);
    }


    /**
     * Load a tile of values (specialized for full tile, any ranking algorithm)
     */
    template <int _RANK_ALGORITHM>
    __device__ __forceinline__ void LoadValues(
        ValueT                      (&values)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        Int2Type<true>              is_full_tile,
        Int2Type<_RANK_ALGORITHM>   rank_algorithm)
    {
        BlockLoadValuesT(temp_storage.load_values).Load(
            d_values_in + block_offset, values);

        CTA_SYNC();
    }


    /**
     * Load a tile of values (specialized for partial tile, any ranking algorithm)
     */
    template <int _RANK_ALGORITHM>
    __device__ __forceinline__ void LoadValues(
        ValueT                      (&values)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        Int2Type<false>             is_full_tile,
        Int2Type<_RANK_ALGORITHM>   rank_algorithm)
    {
        // Register pressure work-around: moving valid_items through shfl prevents compiler
        // from reusing guards/addressing from prior guarded loads
        valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff);

        BlockLoadValuesT(temp_storage.load_values).Load(
            d_values_in + block_offset, values, valid_items);

        CTA_SYNC();
    }


    /**
     * Load a tile of values (specialized for full tile, match ranking algorithm)
     */
    __device__ __forceinline__ void LoadValues(
        ValueT                      (&values)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        Int2Type<true>              is_full_tile,
        Int2Type<RADIX_RANK_MATCH>  rank_algorithm)
    {
        LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values);
    }


    /**
     * Load a tile of values (specialized for partial tile, match ranking algorithm)
     */
    __device__ __forceinline__ void LoadValues(
        ValueT                      (&values)[ITEMS_PER_THREAD],
        OffsetT                     block_offset,
        OffsetT                     valid_items,
        Int2Type<false>             is_full_tile,
        Int2Type<RADIX_RANK_MATCH>  rank_algorithm)
    {
        // Register pressure work-around: moving valid_items through shfl prevents compiler
        // from reusing guards/addressing from prior guarded loads
        valid_items = ShuffleIndex<CUB_PTX_WARP_THREADS>(valid_items, 0, 0xffffffff);

        LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items);
    }


    /**
     * Truck along associated values
     */
    template <bool FULL_TILE>
    __device__ __forceinline__ void GatherScatterValues(
        OffsetT         (&relative_bin_offsets)[ITEMS_PER_THREAD],
        int             (&ranks)[ITEMS_PER_THREAD],
        OffsetT         block_offset,
        OffsetT         valid_items,
        Int2Type<false> /*is_keys_only*/)
    {
        ValueT values[ITEMS_PER_THREAD];

        CTA_SYNC();

        LoadValues(
            values,
            block_offset,
            valid_items,
            Int2Type<FULL_TILE>(),
            Int2Type<RANK_ALGORITHM>());

        ScatterValues<FULL_TILE>(
            values,
            relative_bin_offsets,
            ranks,
            valid_items);
    }


    /**
     * Truck along associated values (specialized for key-only sorting)
     */
    template <bool FULL_TILE>
    __device__ __forceinline__ void GatherScatterValues(
        OffsetT         (&/*relative_bin_offsets*/)[ITEMS_PER_THREAD],
        int             (&/*ranks*/)[ITEMS_PER_THREAD],
        OffsetT         /*block_offset*/,
        OffsetT         /*valid_items*/,
        Int2Type<true>  /*is_keys_only*/)
    {}


    /**
     * Process tile
     */
    template <bool FULL_TILE>
    __device__ __forceinline__ void ProcessTile(
        OffsetT block_offset,
        const OffsetT &valid_items = TILE_ITEMS)
    {
        UnsignedBits    keys[ITEMS_PER_THREAD];
        int             ranks[ITEMS_PER_THREAD];
        OffsetT         relative_bin_offsets[ITEMS_PER_THREAD];

        // Assign default (min/max) value to all keys
        UnsignedBits default_key = (IS_DESCENDING) ? LOWEST_KEY : MAX_KEY;

        // Load tile of keys
        LoadKeys(
            keys,
            block_offset,
            valid_items,
            default_key,
            Int2Type<FULL_TILE>(),
            Int2Type<RANK_ALGORITHM>());

        // Twiddle key bits if necessary
        #pragma unroll
        for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++)
        {
            keys[KEY] = Traits<KeyT>::TwiddleIn(keys[KEY]);
        }

        // Rank the twiddled keys
        int exclusive_digit_prefix[BINS_TRACKED_PER_THREAD];
        BlockRadixRankT(temp_storage.radix_rank).RankKeys(
            keys,
            ranks,
            current_bit,
            num_bits,
            exclusive_digit_prefix);

        CTA_SYNC();

        // Share exclusive digit prefix
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;
            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // Store exclusive prefix
                temp_storage.exclusive_digit_prefix[bin_idx] =
                    exclusive_digit_prefix[track];
            }
        }

        CTA_SYNC();

        // Get inclusive digit prefix
        int inclusive_digit_prefix[BINS_TRACKED_PER_THREAD];

        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;
            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                {
                    // Get inclusive digit prefix from exclusive prefix (higher bins come first)
                    inclusive_digit_prefix[track] = (bin_idx == 0) ?
                        (BLOCK_THREADS * ITEMS_PER_THREAD) :
                        temp_storage.exclusive_digit_prefix[bin_idx - 1];
                }
                else
                {
                    // Get inclusive digit prefix from exclusive prefix (lower bins come first)
                    inclusive_digit_prefix[track] = (bin_idx == RADIX_DIGITS - 1) ?
                        (BLOCK_THREADS * ITEMS_PER_THREAD) :
                        temp_storage.exclusive_digit_prefix[bin_idx + 1];
                }
            }
        }

        CTA_SYNC();

        // Update global scatter base offsets for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;
            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                bin_offset[track] -= exclusive_digit_prefix[track];
                temp_storage.relative_bin_offsets[bin_idx] = bin_offset[track];
                bin_offset[track] += inclusive_digit_prefix[track];
            }
        }

        CTA_SYNC();

        // Scatter keys
        ScatterKeys<FULL_TILE>(keys, relative_bin_offsets, ranks, valid_items);

        // Gather/scatter values
        GatherScatterValues<FULL_TILE>(relative_bin_offsets, ranks, block_offset, valid_items, Int2Type<KEYS_ONLY>());
    }
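
    // Worked example of the offset bookkeeping above (illustrative numbers):
    // suppose digit d's global scatter base is bin_offset = 100 and, within this
    // tile, 7 keys rank ahead of bin d (exclusive_digit_prefix = 7). Then
    // relative_bin_offsets[d] = 100 - 7 = 93, so the key holding block-wide rank
    // r scatters to global position 93 + r: the first key of bin d (r = 7) lands
    // at 100, the next at 101, and so on. Finally, bin_offset advances by
    // (inclusive - exclusive) prefix, i.e., by the number of bin-d keys in this
    // tile, ready for the next tile.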

    //---------------------------------------------------------------------
    // Copy shortcut
    //---------------------------------------------------------------------

    /**
     * Copy tiles within the range of input
     */
    template <
        typename InputIteratorT,
        typename T>
    __device__ __forceinline__ void Copy(
        InputIteratorT  d_in,
        T               *d_out,
        OffsetT         block_offset,
        OffsetT         block_end)
    {
        // Simply copy the input
        while (block_offset + TILE_ITEMS <= block_end)
        {
            T items[ITEMS_PER_THREAD];

            LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items);
            CTA_SYNC();
            StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items);

            block_offset += TILE_ITEMS;
        }

        // Clean up last partial tile with guarded-I/O
        if (block_offset < block_end)
        {
            OffsetT valid_items = block_end - block_offset;

            T items[ITEMS_PER_THREAD];

            LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_in + block_offset, items, valid_items);
            CTA_SYNC();
            StoreDirectStriped<BLOCK_THREADS>(threadIdx.x, d_out + block_offset, items, valid_items);
        }
    }


    /**
     * Copy tiles within the range of input (specialized for NullType)
     */
    template <typename InputIteratorT>
    __device__ __forceinline__ void Copy(
        InputIteratorT  /*d_in*/,
        NullType        * /*d_out*/,
        OffsetT         /*block_offset*/,
        OffsetT         /*block_end*/)
    {}


    //---------------------------------------------------------------------
    // Interface
    //---------------------------------------------------------------------

    /**
     * Constructor
     */
    __device__ __forceinline__ AgentRadixSortDownsweep(
        TempStorage     &temp_storage,
        OffsetT         (&bin_offset)[BINS_TRACKED_PER_THREAD],
        OffsetT         num_items,
        const KeyT      *d_keys_in,
        KeyT            *d_keys_out,
        const ValueT    *d_values_in,
        ValueT          *d_values_out,
        int             current_bit,
        int             num_bits)
    :
        temp_storage(temp_storage.Alias()),
        d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
        d_values_in(d_values_in),
        d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
        d_values_out(d_values_out),
        current_bit(current_bit),
        num_bits(num_bits),
        short_circuit(1)
    {
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            this->bin_offset[track] = bin_offset[track];

            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;
            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                // Short-circuit if the histogram contains only bin counts of zero or the problem size
                short_circuit = short_circuit && ((bin_offset[track] == 0) || (bin_offset[track] == num_items));
            }
        }

        short_circuit = CTA_SYNC_AND(short_circuit);
    }


    /**
     * Constructor
     */
    __device__ __forceinline__ AgentRadixSortDownsweep(
        TempStorage     &temp_storage,
        OffsetT         num_items,
        OffsetT         *d_spine,
        const KeyT      *d_keys_in,
        KeyT            *d_keys_out,
        const ValueT    *d_values_in,
        ValueT          *d_values_out,
        int             current_bit,
        int             num_bits)
    :
        temp_storage(temp_storage.Alias()),
        d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
        d_values_in(d_values_in),
        d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
        d_values_out(d_values_out),
        current_bit(current_bit),
        num_bits(num_bits),
        short_circuit(1)
    {
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;

            // Load digit bin offsets (each of the first RADIX_DIGITS threads will load an offset for that digit)
            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                // Short-circuit if the first block's histogram contains only bin counts of zero or the problem size
                OffsetT first_block_bin_offset = d_spine[gridDim.x * bin_idx];
                short_circuit = short_circuit && ((first_block_bin_offset == 0) || (first_block_bin_offset == num_items));

                // Load my block's bin offset for my bin
                bin_offset[track] = d_spine[(gridDim.x * bin_idx) + blockIdx.x];
            }
        }

        short_circuit = CTA_SYNC_AND(short_circuit);
    }


    /**
     * Distribute keys from a segment of input tiles.
     */
    __device__ __forceinline__ void ProcessRegion(
        OffsetT   block_offset,
        OffsetT   block_end)
    {
        if (short_circuit)
        {
            // Copy keys
            Copy(d_keys_in, d_keys_out, block_offset, block_end);

            // Copy values
            Copy(d_values_in, d_values_out, block_offset, block_end);
        }
        else
        {
            // Process full tiles of tile_items
            #pragma unroll 1
            while (block_offset + TILE_ITEMS <= block_end)
            {
                ProcessTile<true>(block_offset);
                block_offset += TILE_ITEMS;

                CTA_SYNC();
            }

            // Clean up last partial tile with guarded-I/O
            if (block_offset < block_end)
            {
                ProcessTile<false>(block_offset, block_end - block_offset);
            }

        }
    }

};
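
// Usage sketch (illustrative only; the real driver kernels live in CUB's
// device-level dispatch layer). Names such as DownsweepPolicy,
// ExampleDownsweepKernel, and the explicit block_offset/block_end parameters
// are assumptions for demonstration:
//
//     template <typename DownsweepPolicy, typename KeyT, typename OffsetT>
//     __global__ void ExampleDownsweepKernel(
//         OffsetT     *d_spine,        // digit-major spine of bin offsets
//         const KeyT  *d_keys_in,
//         KeyT        *d_keys_out,
//         OffsetT     num_items,
//         int         current_bit,
//         int         num_bits,
//         OffsetT     block_offset,
//         OffsetT     block_end)
//     {
//         // Keys-only, ascending-order agent for this pass
//         typedef AgentRadixSortDownsweep<
//             DownsweepPolicy, false, KeyT, NullType, OffsetT> AgentT;
//
//         __shared__ typename AgentT::TempStorage temp_storage;
//
//         AgentT agent(
//             temp_storage, num_items, d_spine,
//             d_keys_in, d_keys_out,
//             (const NullType*) NULL, (NullType*) NULL,
//             current_bit, num_bits);
//
//         agent.ProcessRegion(block_offset, block_end);
//     }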



}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
