/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../block/block_discontinuity.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentRle
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    bool                        _STORE_WARP_TIME_SLICING,       ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentRlePolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,            ///< Items per thread (per tile of input)
        STORE_WARP_TIME_SLICING = _STORE_WARP_TIME_SLICING,     ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};
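
// Example (illustrative only; the tuned presets actually used by the library
// live in the RLE dispatch layer, device/dispatch/dispatch_rle.cuh): a policy
// instantiation selecting 96 threads per block, 15 items per thread,
// warp-transposed loads through the read-only (LDG) cache, per-warp exchange
// storage, and a warp-scans block scan:
//
//     typedef AgentRlePolicy<
//         96,                             // _BLOCK_THREADS
//         15,                             // _ITEMS_PER_THREAD
//         BLOCK_LOAD_WARP_TRANSPOSE,      // _LOAD_ALGORITHM
//         LOAD_LDG,                       // _LOAD_MODIFIER
//         false,                          // _STORE_WARP_TIME_SLICING
//         BLOCK_SCAN_WARP_SCANS>          // _SCAN_ALGORITHM
//     ExampleRlePolicyT;                  // hypothetical alias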


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode
 */
template <
    typename    AgentRlePolicyT,        ///< Parameterized AgentRlePolicyT tuning policy type
    typename    InputIteratorT,         ///< Random-access input iterator type for data
    typename    OffsetsOutputIteratorT, ///< Random-access output iterator type for offset values
    typename    LengthsOutputIteratorT, ///< Random-access output iterator type for length values
    typename    EqualityOpT,            ///< T equality operator type
    typename    OffsetT>                ///< Signed integer type for global offsets
struct AgentRle
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type T;

    /// The lengths output value type
    typedef typename If<(Equals<typename std::iterator_traits<LengthsOutputIteratorT>::value_type, void>::VALUE),   // LengthT =  (if output iterator's value type is void) ?
        OffsetT,                                                                                                    // ... then the OffsetT type,
        typename std::iterator_traits<LengthsOutputIteratorT>::value_type>::Type LengthT;                           // ... else the output iterator's value type

    /// Tuple type for scanning (pairs run-length and run-index)
    typedef KeyValuePair<OffsetT, LengthT> LengthOffsetPair;

    /// Tile status descriptor interface type
    typedef ReduceByKeyScanTileState<LengthT, OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        WARP_THREADS            = CUB_WARP_THREADS(PTX_ARCH),
        BLOCK_THREADS           = AgentRlePolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentRlePolicyT::ITEMS_PER_THREAD,
        WARP_ITEMS              = WARP_THREADS * ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        /// Whether or not to sync after loading data
        SYNC_AFTER_LOAD         = (AgentRlePolicyT::LOAD_ALGORITHM != BLOCK_LOAD_DIRECT),

        /// Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
        STORE_WARP_TIME_SLICING = AgentRlePolicyT::STORE_WARP_TIME_SLICING,
        ACTIVE_EXCHANGE_WARPS   = (STORE_WARP_TIME_SLICING) ? 1 : WARPS,
    };
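
    // Illustrative arithmetic, assuming the hypothetical example policy above
    // (96 threads, 15 items per thread) and 32-thread warps:
    //   WARP_ITEMS = 32 * 15 = 480
    //   TILE_ITEMS = 96 * 15 = 1440
    //   WARPS      = (96 + 31) / 32 = 3
    // and with STORE_WARP_TIME_SLICING disabled, ACTIVE_EXCHANGE_WARPS = 3.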


    /**
     * Special operator that signals all out-of-bounds items are not equal to everything else,
     * forcing both (1) the last item to be tail-flagged and (2) all oob items to be marked
     * trivial.
     */
    template <bool LAST_TILE>
    struct OobInequalityOp
    {
        OffsetT         num_remaining;
        EqualityOpT     equality_op;

        __device__ __forceinline__ OobInequalityOp(
            OffsetT     num_remaining,
            EqualityOpT equality_op)
        :
            num_remaining(num_remaining),
            equality_op(equality_op)
        {}

        template <typename Index>
        __host__ __device__ __forceinline__ bool operator()(T first, T second, Index idx)
        {
            if (!LAST_TILE || (idx < num_remaining))
                return !equality_op(first, second);
            else
                return true;
        }
    };
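
    // Illustrative note on the operator above: with LAST_TILE = true and
    // num_remaining = 5, any comparison at idx >= 5 reports "not equal"
    // regardless of the items involved. Item 4 therefore receives a tail flag
    // (closing the last real run), and every out-of-bounds item becomes a
    // head-and-tail-flagged singleton -- a trivial run that contributes
    // neither a run count nor any length.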


    // Cache-modified Input iterator wrapper type (for applying cache modifier) for data
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,     // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                       // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Parameterized BlockLoad type for data
    typedef BlockLoad<
            T,
            AgentRlePolicyT::BLOCK_THREADS,
            AgentRlePolicyT::ITEMS_PER_THREAD,
            AgentRlePolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockDiscontinuity type for data
    typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;

    // Parameterized WarpScan type
    typedef WarpScan<LengthOffsetPair> WarpScanPairs;

    // Reduce-length-by-run scan operator
    typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            LengthOffsetPair,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Warp exchange types
    typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>        WarpExchangePairs;

    typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;

    typedef WarpExchange<OffsetT, ITEMS_PER_THREAD>                 WarpExchangeOffsets;
    typedef WarpExchange<LengthT, ITEMS_PER_THREAD>                 WarpExchangeLengths;

    typedef LengthOffsetPair WarpAggregates[WARPS];

    // Shared memory type for this thread block
    struct _TempStorage
    {
        // Aliasable storage layout
        union Aliasable
        {
            struct
            {
                typename BlockDiscontinuityT::TempStorage       discontinuity;              // Smem needed for discontinuity detection
                typename WarpScanPairs::TempStorage             warp_scan[WARPS];           // Smem needed for warp-synchronous scans
                Uninitialized<LengthOffsetPair[WARPS]>          warp_aggregates;            // Smem needed for sharing warp-wide aggregates
                typename TilePrefixCallbackOpT::TempStorage     prefix;                     // Smem needed for cooperative prefix callback
            };

            // Smem needed for input loading
            typename BlockLoadT::TempStorage                    load;

            // Aliasable layout needed for two-phase scatter
            union ScatterAliasable
            {
                unsigned long long                              align;
                WarpExchangePairsStorage                        exchange_pairs[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeOffsets::TempStorage       exchange_offsets[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeLengths::TempStorage       exchange_lengths[ACTIVE_EXCHANGE_WARPS];

            } scatter_aliasable;

        } aliasable;

        OffsetT             tile_idx;                   // Shared tile index
        LengthOffsetPair    tile_inclusive;             // Inclusive tile prefix
        LengthOffsetPair    tile_exclusive;             // Exclusive tile prefix
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};
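
    // Note on the aliasing above (a design observation, not a contract): the
    // load, scan, and scatter members may share storage because each tile is
    // processed in phases -- load, then flag/scan, then exchange/scatter --
    // with synchronization between them, so no two aliased members are
    // intended to be live at the same time.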


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage

    WrappedInputIteratorT           d_in;               ///< Pointer to input sequence of data items
    OffsetsOutputIteratorT          d_offsets_out;      ///< Output run offsets
    LengthsOutputIteratorT          d_lengths_out;      ///< Output run lengths

    EqualityOpT                     equality_op;        ///< T equality operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-length-by-flag scan operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentRle(
        TempStorage                 &temp_storage,      ///< [in] Reference to temp_storage
        InputIteratorT              d_in,               ///< [in] Pointer to input sequence of data items
        OffsetsOutputIteratorT      d_offsets_out,      ///< [out] Pointer to output sequence of run offsets
        LengthsOutputIteratorT      d_lengths_out,      ///< [out] Pointer to output sequence of run lengths
        EqualityOpT                 equality_op,        ///< [in] T equality operator
        OffsetT                     num_items)          ///< [in] Total number of input items
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_offsets_out(d_offsets_out),
        d_lengths_out(d_lengths_out),
        equality_op(equality_op),
        scan_op(cub::Sum()),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    template <bool FIRST_TILE, bool LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT             tile_offset,
        OffsetT             num_remaining,
        T                   (&items)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        bool                head_flags[ITEMS_PER_THREAD];
        bool                tail_flags[ITEMS_PER_THREAD];

        OobInequalityOp<LAST_TILE> inequality_op(num_remaining, equality_op);

        if (FIRST_TILE && LAST_TILE)
        {
            // First-and-last-tile always head-flags the first item and tail-flags the last item

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tail_flags, items, inequality_op);
        }
        else if (FIRST_TILE)
        {
            // First-tile always head-flags the first item

            // Get the first item from the next tile
            T tile_successor_item;
            if (threadIdx.x == BLOCK_THREADS - 1)
                tile_successor_item = d_in[tile_offset + TILE_ITEMS];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tail_flags, tile_successor_item, items, inequality_op);
        }
        else if (LAST_TILE)
        {
            // Last-tile always flags the last item

            // Get the last item from the previous tile
            T tile_predecessor_item;
            if (threadIdx.x == 0)
                tile_predecessor_item = d_in[tile_offset - 1];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tile_predecessor_item, tail_flags, items, inequality_op);
        }
        else
        {
            // Get the first item from the next tile
            T tile_successor_item;
            if (threadIdx.x == BLOCK_THREADS - 1)
                tile_successor_item = d_in[tile_offset + TILE_ITEMS];

            // Get the last item from the previous tile
            T tile_predecessor_item;
            if (threadIdx.x == 0)
                tile_predecessor_item = d_in[tile_offset - 1];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tile_predecessor_item, tail_flags, tile_successor_item, items, inequality_op);
        }

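        // Worked example (illustrative): for the sequence A A B C C C, head
        // flags are 1 0 1 1 0 0 and tail flags 0 1 1 0 0 1, so the zip below
        // produces
        //   key   = 1 0 0 1 0 0   (1 at the start of each non-trivial run)
        //   value = 1 1 0 1 1 1   (1 for each item inside a non-trivial run)
        // and scanning the pairs counts runs while summing run lengths. The
        // singleton B (head and tail both set) contributes nothing.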
        // Zip counts and runs
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            lengths_and_num_runs[ITEM].key      = head_flags[ITEM] && (!tail_flags[ITEM]);
            lengths_and_num_runs[ITEM].value    = ((!head_flags[ITEM]) || (!tail_flags[ITEM]));
        }
    }

    //---------------------------------------------------------------------
    // Scan utility methods
    //---------------------------------------------------------------------

    /**
     * Scan of allocations
     */
    __device__ __forceinline__ void WarpScanAllocations(
        LengthOffsetPair    &tile_aggregate,
        LengthOffsetPair    &warp_aggregate,
        LengthOffsetPair    &warp_exclusive_in_tile,
        LengthOffsetPair    &thread_exclusive_in_warp,
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])
    {
        // Perform warpscans
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        LengthOffsetPair identity;
        identity.key = 0;
        identity.value = 0;

        LengthOffsetPair thread_inclusive;
        LengthOffsetPair thread_aggregate = internal::ThreadReduce(lengths_and_num_runs, scan_op);
        WarpScanPairs(temp_storage.aliasable.warp_scan[warp_id]).Scan(
            thread_aggregate,
            thread_inclusive,
            thread_exclusive_in_warp,
            identity,
            scan_op);

        // Last lane in each warp shares its warp-aggregate
        if (lane_id == WARP_THREADS - 1)
            temp_storage.aliasable.warp_aggregates.Alias()[warp_id] = thread_inclusive;

        CTA_SYNC();

        // Accumulate total selected and the warp-wide prefix
        warp_exclusive_in_tile          = identity;
        warp_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[warp_id];
        tile_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[0];

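        // Fold the remaining warp aggregates into tile_aggregate, handing each
        // warp its exclusive prefix along the way. E.g. (illustrative) with
        // WARPS = 3, warp 2 receives scan_op(agg[0], agg[1]) as its
        // warp_exclusive_in_tile, and tile_aggregate accumulates all three.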
        #pragma unroll
        for (int WARP = 1; WARP < WARPS; ++WARP)
        {
            if (warp_id == WARP)
                warp_exclusive_in_tile = tile_aggregate;

            tile_aggregate = scan_op(tile_aggregate, temp_storage.aliasable.warp_aggregates.Alias()[WARP]);
        }
    }


    //---------------------------------------------------------------------
    // Utility methods for scattering selections
    //---------------------------------------------------------------------
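
    // Output convention shared by the scatter methods below: run i's starting
    // offset is written to d_offsets_out[i], while run lengths trail by one
    // position -- the scan value carried at the head of run i is the completed
    // length of run i-1, so it is written to d_lengths_out[i - 1] (guarded so
    // the first global run writes no predecessor). The final run's length is
    // emitted by ConsumeRange() after the last tile is consumed.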

    /**
     * Two-phase scatter, specialized for warp time-slicing
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OffsetT             tile_num_runs_exclusive_in_global,
        OffsetT             warp_num_runs_aggregate,
        OffsetT             warp_num_runs_exclusive_in_tile,
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],
        Int2Type<true>      is_warp_time_slice)
    {
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        // Locally compact items within the warp (first warp)
        if (warp_id == 0)
        {
            WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped(
                lengths_and_offsets, thread_num_runs_exclusive_in_warp);
        }

        // Locally compact items within the warp (remaining warps)
        #pragma unroll
        for (int SLICE = 1; SLICE < WARPS; ++SLICE)
        {
            CTA_SYNC();

            if (warp_id == SLICE)
            {
                WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped(
                    lengths_and_offsets, thread_num_runs_exclusive_in_warp);
            }
        }

        // Global scatter
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            if ((ITEM * WARP_THREADS) < warp_num_runs_aggregate - lane_id)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    (ITEM * WARP_THREADS) + lane_id;

                // Scatter offset
                d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

                // Scatter length if not the first (global) length
                if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0))
                {
                    d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
                }
            }
        }
    }


    /**
     * Two-phase scatter
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OffsetT             tile_num_runs_exclusive_in_global,
        OffsetT             warp_num_runs_aggregate,
        OffsetT             warp_num_runs_exclusive_in_tile,
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],
        Int2Type<false>     is_warp_time_slice)
    {
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        // Unzip
        OffsetT run_offsets[ITEMS_PER_THREAD];
        LengthT run_lengths[ITEMS_PER_THREAD];

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            run_offsets[ITEM] = lengths_and_offsets[ITEM].key;
            run_lengths[ITEM] = lengths_and_offsets[ITEM].value;
        }

        WarpExchangeOffsets(temp_storage.aliasable.scatter_aliasable.exchange_offsets[warp_id]).ScatterToStriped(
            run_offsets, thread_num_runs_exclusive_in_warp);

        WARP_SYNC(0xffffffff);

        WarpExchangeLengths(temp_storage.aliasable.scatter_aliasable.exchange_lengths[warp_id]).ScatterToStriped(
            run_lengths, thread_num_runs_exclusive_in_warp);

        // Global scatter
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            if ((ITEM * WARP_THREADS) + lane_id < warp_num_runs_aggregate)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    (ITEM * WARP_THREADS) + lane_id;

                // Scatter offset
                d_offsets_out[item_offset] = run_offsets[ITEM];

                // Scatter length if not the first (global) length
                if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0))
                {
                    d_lengths_out[item_offset - 1] = run_lengths[ITEM];
                }
            }
        }
    }


    /**
     * Direct scatter
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterDirect(
        OffsetT             tile_num_runs_exclusive_in_global,
        OffsetT             warp_num_runs_aggregate,
        OffsetT             warp_num_runs_exclusive_in_tile,
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (thread_num_runs_exclusive_in_warp[ITEM] < warp_num_runs_aggregate)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    thread_num_runs_exclusive_in_warp[ITEM];

                // Scatter offset
                d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

                // Scatter length if not the first (global) length
                if (item_offset >= 1)
                {
                    d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
                }
            }
        }
    }


    /**
     * Scatter
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void Scatter(
        OffsetT             tile_num_runs_aggregate,
        OffsetT             tile_num_runs_exclusive_in_global,
        OffsetT             warp_num_runs_aggregate,
        OffsetT             warp_num_runs_exclusive_in_tile,
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD],
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])
    {
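        // Strategy choice: if each thread holds only one item, or the tile
        // contains fewer runs than the block has threads, two-phase compaction
        // in shared memory is not worth its overhead and runs are scattered
        // directly; otherwise runs are first compacted so that the warp's
        // global writes land densely.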
        if ((ITEMS_PER_THREAD == 1) || (tile_num_runs_aggregate < BLOCK_THREADS))
        {
            // Direct scatter if the warp has any items
            if (warp_num_runs_aggregate)
            {
                ScatterDirect<FIRST_TILE>(
                    tile_num_runs_exclusive_in_global,
                    warp_num_runs_aggregate,
                    warp_num_runs_exclusive_in_tile,
                    thread_num_runs_exclusive_in_warp,
                    lengths_and_offsets);
            }
        }
        else
        {
            // Scatter two phase
            ScatterTwoPhase<FIRST_TILE>(
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets,
                Int2Type<STORE_WARP_TIME_SLICING>());
        }
    }



    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process a tile of input (dynamic chained scan)
     */
    template <
        bool                LAST_TILE>
    __device__ __forceinline__ LengthOffsetPair ConsumeTile(
        OffsetT             num_items,          ///< Total number of global input items
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT      &tile_status)       ///< Global list of tile status
    {
        if (tile_idx == 0)
        {
            // First tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);

            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<true, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // Update tile status if this is not the last tile
            if (!LAST_TILE && (threadIdx.x == 0))
                tile_status.SetInclusive(0, tile_aggregate);

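            // A thread whose warp-local exclusive prefix contains no run heads
            // (key == 0) is continuing the run in progress from preceding
            // warps, so the length that run has accumulated so far is folded
            // into the thread's prefix below.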
            // Update thread_exclusive_in_warp to fold in warp run-length
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += warp_exclusive_in_tile.value;

            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];

            // Downsweep scan through lengths_and_num_runs
            internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip
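            // The scatter rank below either keeps an item (a run head's rank is
            // the count of run heads preceding it in the warp) or discards it by
            // assigning the out-of-range rank WARP_THREADS * ITEMS_PER_THREAD,
            // which the guarded scatter logic never reads back.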
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key           = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                                lengths_and_num_runs2[ITEM].key :          // keep
                                                                WARP_THREADS * ITEMS_PER_THREAD;           // discard
            }

            OffsetT tile_num_runs_aggregate              = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global    = 0;
            OffsetT warp_num_runs_aggregate              = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile      = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<true>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return tile_aggregate;
        }
        else
        {
            // Not first tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);

            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<false, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // First warp computes tile prefix in lane 0
            TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.aliasable.prefix, Sum(), tile_idx);
            unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
            if (warp_id == 0)
            {
                prefix_op(tile_aggregate);
                if (threadIdx.x == 0)
                    temp_storage.tile_exclusive = prefix_op.exclusive_prefix;
            }

            CTA_SYNC();

            LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive;

            // Update thread_exclusive_in_warp to fold in warp and tile run-lengths
            LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile);
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += thread_exclusive.value;

            // Downsweep scan through lengths_and_num_runs
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];

            internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key           = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                                lengths_and_num_runs2[ITEM].key :          // keep
                                                                WARP_THREADS * ITEMS_PER_THREAD;           // discard
            }

            OffsetT tile_num_runs_aggregate              = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global    = tile_exclusive_in_global.key;
            OffsetT warp_num_runs_aggregate              = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile      = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<false>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return prefix_op.inclusive_prefix;
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    template <typename NumRunsIteratorT>            ///< Output iterator type for recording number of runs identified
    __device__ __forceinline__ void ConsumeRange(
        int                 num_tiles,              ///< Total number of input tiles
        ScanTileStateT&     tile_status,            ///< Global list of tile status
        NumRunsIteratorT    d_num_runs_out)         ///< Output pointer for total number of runs identified
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index
        OffsetT tile_offset     = tile_idx * TILE_ITEMS;                    // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                  // Remaining items (including this tile)

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_items, num_remaining, tile_idx, tile_offset, tile_status);
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            LengthOffsetPair running_total = ConsumeTile<true>(num_items, num_remaining, tile_idx, tile_offset, tile_status);

            if (threadIdx.x == 0)
            {
                // Output the total number of runs identified
                *d_num_runs_out = running_total.key;

                // The inclusive prefix contains accumulated length reduction for the last run
                if (running_total.key > 0)
                    d_lengths_out[running_total.key - 1] = running_total.value;
            }
        }
    }
};
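
/**
 * Example (an illustrative sketch, not the library's actual dispatch code --
 * the real driver is the sweep kernel in device/dispatch/dispatch_rle.cuh,
 * which also initializes the tile-status descriptors and selects tuned
 * policies per architecture): a minimal kernel of the kind that drives
 * AgentRle, assuming \p tile_status has already been initialized and the
 * grid supplies one thread block per tile.
 *
 * \code
 * template <
 *     typename AgentRlePolicyT,
 *     typename InputIteratorT,
 *     typename OffsetsOutputIteratorT,
 *     typename LengthsOutputIteratorT,
 *     typename NumRunsOutputIteratorT,
 *     typename ScanTileStateT,
 *     typename EqualityOpT,
 *     typename OffsetT>
 * __launch_bounds__ (int(AgentRlePolicyT::BLOCK_THREADS))
 * __global__ void ExampleRleSweepKernel(      // hypothetical name
 *     InputIteratorT          d_in,
 *     OffsetsOutputIteratorT  d_offsets_out,
 *     LengthsOutputIteratorT  d_lengths_out,
 *     NumRunsOutputIteratorT  d_num_runs_out,
 *     ScanTileStateT          tile_status,    // matching AgentRleT::ScanTileStateT
 *     EqualityOpT             equality_op,
 *     OffsetT                 num_items,
 *     int                     num_tiles)
 * {
 *     // Specialize the agent for the given policy and problem types
 *     typedef AgentRle<
 *         AgentRlePolicyT, InputIteratorT, OffsetsOutputIteratorT,
 *         LengthsOutputIteratorT, EqualityOpT, OffsetT> AgentRleT;
 *
 *     // Shared memory for AgentRle
 *     __shared__ typename AgentRleT::TempStorage temp_storage;
 *
 *     // Consume this block's tile of input
 *     AgentRleT(temp_storage, d_in, d_offsets_out, d_lengths_out, equality_op, num_items)
 *         .ConsumeRange(num_tiles, tile_status, d_num_runs_out);
 * }
 * \endcode
 */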


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
