/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep.
 */

#pragma once

#include "../thread/thread_reduce.cuh"
#include "../thread/thread_load.cuh"
#include "../warp/warp_reduce.cuh"
#include "../block/block_load.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentRadixSortUpsweep
 */
template <
    int                 _BLOCK_THREADS,     ///< Threads per thread block
    int                 _ITEMS_PER_THREAD,  ///< Items per thread (per tile of input)
    CacheLoadModifier   _LOAD_MODIFIER,     ///< Cache load modifier for reading keys
    int                 _RADIX_BITS>        ///< The number of radix bits, i.e., log2(bins)
struct AgentRadixSortUpsweepPolicy
{
    enum
    {
        BLOCK_THREADS       = _BLOCK_THREADS,       ///< Threads per thread block
        ITEMS_PER_THREAD    = _ITEMS_PER_THREAD,    ///< Items per thread (per tile of input)
        RADIX_BITS          = _RADIX_BITS,          ///< The number of radix bits, i.e., log2(bins)
    };

    static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER;      ///< Cache load modifier for reading keys
};
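
// Example (illustrative only; these tuning values are hypothetical, not an official
// CUB configuration): a policy selecting 128 threads per block, 4 keys per thread,
// streaming cached loads, and 5-bit digits would be declared as
//
//     typedef AgentRadixSortUpsweepPolicy<128, 4, LOAD_LDG, 5> CustomUpsweepPolicy;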


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentRadixSortUpsweep implements a stateful abstraction of CUDA thread blocks for participating in device-wide radix sort upsweep.
 */
template <
    typename AgentRadixSortUpsweepPolicy,   ///< Parameterized AgentRadixSortUpsweepPolicy tuning policy type
    typename KeyT,                          ///< KeyT type
    typename OffsetT>                       ///< Signed integer type for global offsets
struct AgentRadixSortUpsweep
{

    //---------------------------------------------------------------------
    // Type definitions and constants
    //---------------------------------------------------------------------

    typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

    // Integer type for digit counters (to be packed into words of PackedCounters)
    typedef unsigned char DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks
    typedef unsigned int PackedCounter;

    static const CacheLoadModifier LOAD_MODIFIER = AgentRadixSortUpsweepPolicy::LOAD_MODIFIER;

    enum
    {
        RADIX_BITS              = AgentRadixSortUpsweepPolicy::RADIX_BITS,
        BLOCK_THREADS           = AgentRadixSortUpsweepPolicy::BLOCK_THREADS,
        KEYS_PER_THREAD         = AgentRadixSortUpsweepPolicy::ITEMS_PER_THREAD,

        RADIX_DIGITS            = 1 << RADIX_BITS,

        LOG_WARP_THREADS        = CUB_PTX_LOG_WARP_THREADS,
        WARP_THREADS            = 1 << LOG_WARP_THREADS,
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        TILE_ITEMS              = BLOCK_THREADS * KEYS_PER_THREAD,

        BYTES_PER_COUNTER       = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER   = Log2<BYTES_PER_COUNTER>::VALUE,

        PACKING_RATIO           = sizeof(PackedCounter) / sizeof(DigitCounter),
        LOG_PACKING_RATIO       = Log2<PACKING_RATIO>::VALUE,

        LOG_COUNTER_LANES       = CUB_MAX(0, RADIX_BITS - LOG_PACKING_RATIO),
        COUNTER_LANES           = 1 << LOG_COUNTER_LANES,

        // To prevent counter overflow, we must periodically unpack and aggregate the
        // digit counters back into registers.  Each counter lane is assigned to a
        // warp for aggregation.

        LANES_PER_WARP          = CUB_MAX(1, (COUNTER_LANES + WARPS - 1) / WARPS),

        // Unroll tiles in batches without risk of counter overflow
        UNROLL_COUNT            = CUB_MIN(64, 255 / KEYS_PER_THREAD),
        UNROLLED_ELEMENTS       = UNROLL_COUNT * TILE_ITEMS,
    };
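
    // Worked example of the derived constants above (illustrative only, using the
    // hypothetical tuning of 128 threads, 4 keys per thread, and RADIX_BITS = 5):
    //   RADIX_DIGITS      = 32
    //   WARPS             = 128 / 32 = 4
    //   TILE_ITEMS        = 128 * 4 = 512
    //   PACKING_RATIO     = sizeof(unsigned int) / sizeof(unsigned char) = 4   (LOG_PACKING_RATIO = 2)
    //   COUNTER_LANES     = 1 << (5 - 2) = 8      (8 packed words per thread, 4 byte-counters each = 32 digit bins)
    //   LANES_PER_WARP    = (8 + 4 - 1) / 4 = 2   (each warp aggregates 2 of the 8 lanes)
    //   UNROLL_COUNT      = min(64, 255 / 4) = 63 (an 8-bit DigitCounter can absorb at most 255
    //                                              increments per thread, i.e., 63 tiles of 4 keys,
    //                                              before it must be unpacked into registers)
    //   UNROLLED_ELEMENTS = 63 * 512 = 32256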


    // Input iterator wrapper type (for applying cache modifiers)
    typedef CacheModifiedInputIterator<LOAD_MODIFIER, UnsignedBits, OffsetT> KeysItr;

    /**
     * Shared memory storage layout
     */
    union __align__(16) _TempStorage
    {
        DigitCounter    thread_counters[COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
        PackedCounter   packed_thread_counters[COUNTER_LANES][BLOCK_THREADS];
        OffsetT         block_counters[WARP_THREADS][RADIX_DIGITS];
    };


    /// Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Thread fields (aggregate state bundle)
    //---------------------------------------------------------------------

    // Shared storage for this CTA
    _TempStorage    &temp_storage;

    // Thread-local counters for periodically aggregating composite-counter lanes
    OffsetT         local_counts[LANES_PER_WARP][PACKING_RATIO];

    // Input keys (wrapped to apply the cache load modifier)
    KeysItr         d_keys_in;

    // The least-significant bit position of the current digit to extract
    int             current_bit;

    // Number of bits in current digit
    int             num_bits;


    //---------------------------------------------------------------------
    // Helper structure for templated iteration
    //---------------------------------------------------------------------

    // Iterate
    template <int COUNT, int MAX>
    struct Iterate
    {
        // BucketKeys
        static __device__ __forceinline__ void BucketKeys(
            AgentRadixSortUpsweep       &cta,
            UnsignedBits                keys[KEYS_PER_THREAD])
        {
            cta.Bucket(keys[COUNT]);

            // Next
            Iterate<COUNT + 1, MAX>::BucketKeys(cta, keys);
        }
    };

    // Terminate
    template <int MAX>
    struct Iterate<MAX, MAX>
    {
        // BucketKeys
        static __device__ __forceinline__ void BucketKeys(AgentRadixSortUpsweep &/*cta*/, UnsignedBits /*keys*/[KEYS_PER_THREAD]) {}
    };


    //---------------------------------------------------------------------
    // Utility methods
    //---------------------------------------------------------------------

    /**
     * Decode a key and increment corresponding smem digit counter
     */
    __device__ __forceinline__ void Bucket(UnsignedBits key)
    {
        // Perform transform op
        UnsignedBits converted_key = Traits<KeyT>::TwiddleIn(key);

        // Extract current digit bits
        UnsignedBits digit = BFE(converted_key, current_bit, num_bits);

        // Get sub-counter offset
        UnsignedBits sub_counter = digit & (PACKING_RATIO - 1);

        // Get row offset
        UnsignedBits row_offset = digit >> LOG_PACKING_RATIO;

        // Increment counter
        temp_storage.thread_counters[row_offset][threadIdx.x][sub_counter]++;
    }
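
    // Illustrative example (assuming RADIX_BITS = 5 and PACKING_RATIO = 4): a key whose
    // current 5-bit digit is 13 (binary 01101) is counted in sub-counter 13 & 3 = 1 of
    // packed word packed_thread_counters[13 >> 2][threadIdx.x], i.e., counter lane 3.
    // Because each thread increments only its own column of counters, no atomics are needed.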


    /**
     * Reset composite counters
     */
    __device__ __forceinline__ void ResetDigitCounters()
    {
        #pragma unroll
        for (int LANE = 0; LANE < COUNTER_LANES; LANE++)
        {
            temp_storage.packed_thread_counters[LANE][threadIdx.x] = 0;
        }
    }


    /**
     * Reset the unpacked counters in each thread
     */
    __device__ __forceinline__ void ResetUnpackedCounters()
    {
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            #pragma unroll
            for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
            {
                local_counts[LANE][UNPACKED_COUNTER] = 0;
            }
        }
    }


    /**
     * Extracts and aggregates the digit counters for each counter lane
     * owned by this warp
     */
    __device__ __forceinline__ void UnpackDigitCounts()
    {
        unsigned int warp_id = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid = LaneId();

        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            const int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                #pragma unroll
                for (int PACKED_COUNTER = 0; PACKED_COUNTER < BLOCK_THREADS; PACKED_COUNTER += WARP_THREADS)
                {
                    #pragma unroll
                    for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                    {
                        OffsetT counter = temp_storage.thread_counters[counter_lane][warp_tid + PACKED_COUNTER][UNPACKED_COUNTER];
                        local_counts[LANE][UNPACKED_COUNTER] += counter;
                    }
                }
            }
        }
    }
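
    // Illustrative example (assuming COUNTER_LANES = 8 and WARPS = 4): warp w aggregates
    // counter lanes w and w + 4 (counter_lane = LANE * WARPS + warp_id for LANE = 0, 1),
    // with each of its WARP_THREADS lanes striding across the BLOCK_THREADS columns of
    // byte-counters in that counter lane.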


    /**
     * Processes a single, full tile
     */
    __device__ __forceinline__ void ProcessFullTile(OffsetT block_offset)
    {
        // Tile of keys
        UnsignedBits keys[KEYS_PER_THREAD];

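        // Striped load: thread t reads the keys at tile offsets t, t + BLOCK_THREADS,
        // t + 2 * BLOCK_THREADS, ..., so consecutive threads touch consecutive keys and
        // the global reads coalesce.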
        LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_keys_in + block_offset, keys);

        // Prevent hoisting
        CTA_SYNC();

        // Bucket tile of keys
        Iterate<0, KEYS_PER_THREAD>::BucketKeys(*this, keys);
    }


    /**
     * Processes a partial tile (may have some threads masked off)
     */
    __device__ __forceinline__ void ProcessPartialTile(
        OffsetT block_offset,
        const OffsetT &block_end)
    {
        // Process partial tile if necessary using single loads
        block_offset += threadIdx.x;
        while (block_offset < block_end)
        {
            // Load and bucket key
            UnsignedBits key = d_keys_in[block_offset];
            Bucket(key);
            block_offset += BLOCK_THREADS;
        }
    }


    //---------------------------------------------------------------------
    // Interface
    //---------------------------------------------------------------------

    /**
     * Constructor
     */
    __device__ __forceinline__ AgentRadixSortUpsweep(
        TempStorage &temp_storage,
        const KeyT  *d_keys_in,
        int         current_bit,
        int         num_bits)
    :
        temp_storage(temp_storage.Alias()),
        d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
        current_bit(current_bit),
        num_bits(num_bits)
    {}


    /**
     * Compute radix digit histograms from a segment of input tiles.
     */
    __device__ __forceinline__ void ProcessRegion(
        OffsetT          block_offset,
        const OffsetT    &block_end)
    {
        // Reset digit counters in smem and unpacked counters in registers
        ResetDigitCounters();
        ResetUnpackedCounters();

        // Unroll batches of full tiles
        while (block_offset + UNROLLED_ELEMENTS <= block_end)
        {
            for (int i = 0; i < UNROLL_COUNT; ++i)
            {
                ProcessFullTile(block_offset);
                block_offset += TILE_ITEMS;
            }

            CTA_SYNC();

            // Aggregate back into local_count registers to prevent overflow
            UnpackDigitCounts();

            CTA_SYNC();

            // Reset composite counters in lanes
            ResetDigitCounters();
        }

        // Unroll single full tiles
        while (block_offset + TILE_ITEMS <= block_end)
        {
            ProcessFullTile(block_offset);
            block_offset += TILE_ITEMS;
        }

        // Process partial tile if necessary
        ProcessPartialTile(
            block_offset,
            block_end);

        CTA_SYNC();

        // Aggregate back into local_count registers
        UnpackDigitCounts();
    }
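
    // After ProcessRegion() returns, every key in [block_offset, block_end) has been tallied:
    // the totals live in each warp's local_counts registers (per counter lane and sub-counter),
    // and a subsequent ExtractCounts() call merges them into per-digit totals for the block.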


    /**
     * Extract counts (saving them to the external array)
     */
    template <bool IS_DESCENDING>
    __device__ __forceinline__ void ExtractCounts(
        OffsetT     *counters,
        int         bin_stride = 1,
        int         bin_offset = 0)
    {
        unsigned int warp_id    = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid   = LaneId();

        // Place unpacked digit counters in shared memory
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                int digit_row = counter_lane << LOG_PACKING_RATIO;

                #pragma unroll
                for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                {
                    int bin_idx = digit_row + UNPACKED_COUNTER;

                    temp_storage.block_counters[warp_tid][bin_idx] =
                        local_counts[LANE][UNPACKED_COUNTER];
                }
            }
        }

        CTA_SYNC();

        // Rake-reduce bin_count reductions
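        // (each participating thread sums the WARP_THREADS per-warp partial counts for
        // one digit bin and writes the total to the caller's counters array)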

        // Whole blocks
        #pragma unroll
        for (int BIN_BASE   = RADIX_DIGITS % BLOCK_THREADS;
            (BIN_BASE + BLOCK_THREADS) <= RADIX_DIGITS;
            BIN_BASE += BLOCK_THREADS)
        {
            int bin_idx = BIN_BASE + threadIdx.x;

            OffsetT bin_count = 0;
            #pragma unroll
            for (int i = 0; i < WARP_THREADS; ++i)
                bin_count += temp_storage.block_counters[i][bin_idx];

            if (IS_DESCENDING)
                bin_idx = RADIX_DIGITS - bin_idx - 1;

            counters[(bin_stride * bin_idx) + bin_offset] = bin_count;
        }

        // Remainder
        if ((RADIX_DIGITS % BLOCK_THREADS != 0) && (threadIdx.x < RADIX_DIGITS))
        {
            int bin_idx = threadIdx.x;

            OffsetT bin_count = 0;
            #pragma unroll
            for (int i = 0; i < WARP_THREADS; ++i)
                bin_count += temp_storage.block_counters[i][bin_idx];

            if (IS_DESCENDING)
                bin_idx = RADIX_DIGITS - bin_idx - 1;

            counters[(bin_stride * bin_idx) + bin_offset] = bin_count;
        }
    }


    /**
     * Extract counts (saving them into per-thread register bins)
     */
    template <int BINS_TRACKED_PER_THREAD>
    __device__ __forceinline__ void ExtractCounts(
        OffsetT (&bin_count)[BINS_TRACKED_PER_THREAD])  ///< [out] The digit counts for bins [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        unsigned int warp_id    = threadIdx.x >> LOG_WARP_THREADS;
        unsigned int warp_tid   = LaneId();

        // Place unpacked digit counters in shared memory
        #pragma unroll
        for (int LANE = 0; LANE < LANES_PER_WARP; LANE++)
        {
            int counter_lane = (LANE * WARPS) + warp_id;
            if (counter_lane < COUNTER_LANES)
            {
                int digit_row = counter_lane << LOG_PACKING_RATIO;

                #pragma unroll
                for (int UNPACKED_COUNTER = 0; UNPACKED_COUNTER < PACKING_RATIO; UNPACKED_COUNTER++)
                {
                    int bin_idx = digit_row + UNPACKED_COUNTER;

                    temp_storage.block_counters[warp_tid][bin_idx] =
                        local_counts[LANE][UNPACKED_COUNTER];
                }
            }
        }

        CTA_SYNC();

        // Rake-reduce bin_count reductions
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (threadIdx.x * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                bin_count[track] = 0;

                #pragma unroll
                for (int i = 0; i < WARP_THREADS; ++i)
                    bin_count[track] += temp_storage.block_counters[i][bin_idx];
            }
        }
    }

};
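

/**
 * Usage sketch (illustrative only -- the kernel below is hypothetical and is not the
 * library's dispatch kernel; it simply shows how the agent's members fit together):
 *
 * \code
 * template <typename PolicyT, typename KeyT, typename OffsetT>
 * __global__ void ExampleUpsweepKernel(
 *     const KeyT  *d_keys_in,     // Input keys
 *     OffsetT     *d_counters,    // Output counts (RADIX_DIGITS * gridDim.x entries, digit-major)
 *     int         current_bit,    // Least-significant bit of the current digit
 *     int         num_bits,       // Width of the current digit
 *     OffsetT     block_offset,   // First key assigned to this block
 *     OffsetT     block_end)      // One past the last key assigned to this block
 * {
 *     typedef AgentRadixSortUpsweep<PolicyT, KeyT, OffsetT> AgentT;
 *
 *     // Shared memory needed by the agent
 *     __shared__ typename AgentT::TempStorage temp_storage;
 *
 *     // Histogram this block's segment of keys
 *     AgentT agent(temp_storage, d_keys_in, current_bit, num_bits);
 *     agent.ProcessRegion(block_offset, block_end);
 *
 *     CTA_SYNC();
 *
 *     // Write the per-digit totals, strided so each digit's counts are grouped by block
 *     agent.template ExtractCounts<false>(d_counters, gridDim.x, blockIdx.x);
 * }
 * \endcode
 */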


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
