/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_discontinuity.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentReduceByKey
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentReduceByKeyPolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,            ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};
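
// An illustrative (hypothetical, untuned) policy instantiation; the numeric
// values here are assumptions for the sake of example, not a recommended
// configuration:
//
//     typedef AgentReduceByKeyPolicy<
//         128,                        // BLOCK_THREADS
//         8,                          // ITEMS_PER_THREAD
//         BLOCK_LOAD_DIRECT,          // LOAD_ALGORITHM
//         LOAD_LDG,                   // LOAD_MODIFIER
//         BLOCK_SCAN_WARP_SCANS>      // SCAN_ALGORITHM
//     ExamplePolicy;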


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key
 */
template <
    typename    AgentReduceByKeyPolicyT,        ///< Parameterized AgentReduceByKeyPolicy tuning policy type
    typename    KeysInputIteratorT,             ///< Random-access input iterator type for keys
    typename    UniqueOutputIteratorT,          ///< Random-access output iterator type for keys
    typename    ValuesInputIteratorT,           ///< Random-access input iterator type for values
    typename    AggregatesOutputIteratorT,      ///< Random-access output iterator type for values
    typename    NumRunsOutputIteratorT,         ///< Output iterator type for recording number of items selected
    typename    EqualityOpT,                    ///< KeyT equality operator type
    typename    ReductionOpT,                   ///< ValueT reduction operator type
    typename    OffsetT>                        ///< Signed integer type for global offsets
struct AgentReduceByKey
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // The input keys type
    typedef typename std::iterator_traits<KeysInputIteratorT>::value_type KeyInputT;

    // The output keys type
    typedef typename If<(Equals<typename std::iterator_traits<UniqueOutputIteratorT>::value_type, void>::VALUE),    // KeyOutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<KeysInputIteratorT>::value_type,                                              // ... then the input iterator's value type,
        typename std::iterator_traits<UniqueOutputIteratorT>::value_type>::Type KeyOutputT;                         // ... else the output iterator's value type

    // The input values type
    typedef typename std::iterator_traits<ValuesInputIteratorT>::value_type ValueInputT;

    // The output values type
    typedef typename If<(Equals<typename std::iterator_traits<AggregatesOutputIteratorT>::value_type, void>::VALUE),    // ValueOutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<ValuesInputIteratorT>::value_type,                                                // ... then the input iterator's value type,
        typename std::iterator_traits<AggregatesOutputIteratorT>::value_type>::Type ValueOutputT;                       // ... else the output iterator's value type

    // Tuple type for scanning (pairs accumulated segment-value with segment-index)
    typedef KeyValuePair<OffsetT, ValueOutputT> OffsetValuePairT;

    // Tuple type for pairing keys and values
    typedef KeyValuePair<KeyOutputT, ValueOutputT> KeyValuePairT;

    // Tile status descriptor interface type
    typedef ReduceByKeyScanTileState<ValueOutputT, OffsetT> ScanTileStateT;

    // Guarded inequality functor
    template <typename _EqualityOpT>
    struct GuardedInequalityWrapper
    {
        _EqualityOpT    op;             ///< Wrapped equality operator
        int             num_remaining;  ///< Items remaining

        /// Constructor
        __host__ __device__ __forceinline__
        GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) : op(op), num_remaining(num_remaining) {}

        /// Boolean inequality operator, returns <tt>(a != b)</tt>
        template <typename T>
        __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const
        {
            if (idx < num_remaining)
                return !op(a, b);   // In bounds

            // Return true if first out-of-bounds item, false otherwise
            return (idx == num_remaining);
        }
    };
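
    // A worked illustration (assumed inputs, not taken from the source): with
    // num_remaining = 3 and tile items [a, a, b, x, x] (the trailing x's are
    // out-of-bounds garbage), the resulting head flags are [?, 0, 1, 1, 0]:
    // in-bounds items are flagged via the wrapped (in)equality operator (the
    // first item's flag depends on the tile predecessor), the first
    // out-of-bounds index (idx == num_remaining) is forcibly flagged so the
    // last valid segment is terminated, and later out-of-bounds indices are
    // left unflagged.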


    // Constants
    enum
    {
        BLOCK_THREADS       = AgentReduceByKeyPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS          = BLOCK_THREADS * ITEMS_PER_THREAD,
        TWO_PHASE_SCATTER   = (ITEMS_PER_THREAD > 1),

        // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type)
        HAS_IDENTITY_ZERO   = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueOutputT>::PRIMITIVE),
    };
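
    // For example (illustrative, not exhaustive): cub::Sum over int yields
    // HAS_IDENTITY_ZERO = 1, while cub::Max over int or cub::Sum over a
    // user-defined value struct yields HAS_IDENTITY_ZERO = 0.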

    // Cache-modified input iterator wrapper type (for applying cache modifier) for keys
    typedef typename If<IsPointer<KeysInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>,     // Wrap the native input pointer with CacheModifiedInputIterator
            KeysInputIteratorT>::Type                                                                   // Directly use the supplied input iterator type
        WrappedKeysInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for values
    typedef typename If<IsPointer<ValuesInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
            ValuesInputIteratorT>::Type                                                                 // Directly use the supplied input iterator type
        WrappedValuesInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for fixup values
    typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
            AggregatesOutputIteratorT>::Type                                                            // Directly use the supplied input iterator type
        WrappedFixupInputIteratorT;

    // Reduce-value-by-segment scan operator
    typedef ReduceBySegmentOp<ReductionOpT> ReduceBySegmentOpT;

    // Parameterized BlockLoad type for keys
    typedef BlockLoad<
            KeyOutputT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentReduceByKeyPolicyT::LOAD_ALGORITHM>
        BlockLoadKeysT;

    // Parameterized BlockLoad type for values
    typedef BlockLoad<
            ValueOutputT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentReduceByKeyPolicyT::LOAD_ALGORITHM>
        BlockLoadValuesT;

    // Parameterized BlockDiscontinuity type for keys
    typedef BlockDiscontinuity<
            KeyOutputT,
            BLOCK_THREADS>
        BlockDiscontinuityKeys;

    // Parameterized BlockScan type
    typedef BlockScan<
            OffsetValuePairT,
            BLOCK_THREADS,
            AgentReduceByKeyPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            OffsetValuePairT,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Key and value exchange types
    typedef KeyOutputT    KeyExchangeT[TILE_ITEMS + 1];
    typedef ValueOutputT  ValueExchangeT[TILE_ITEMS + 1];

    // Shared memory type for this thread block
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
            typename BlockDiscontinuityKeys::TempStorage    discontinuity;  // Smem needed for discontinuity detection
        };

        // Smem needed for loading keys
        typename BlockLoadKeysT::TempStorage load_keys;

        // Smem needed for loading values
        typename BlockLoadValuesT::TempStorage load_values;

        // Smem needed for compacting key-value pairs (allows non-POD items in this union)
        Uninitialized<KeyValuePairT[TILE_ITEMS + 1]> raw_exchange;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};
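
    // A minimal usage sketch (assumed caller context, not part of this file):
    // the enclosing kernel allocates the storage union once in shared memory
    // and hands it to the agent, e.g.
    //
    //     __shared__ typename AgentReduceByKeyT::TempStorage temp_storage;
    //
    // where AgentReduceByKeyT is a specialization of this struct.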


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedKeysInputIteratorT       d_keys_in;          ///< Input keys
    UniqueOutputIteratorT           d_unique_out;       ///< Unique output keys
    WrappedValuesInputIteratorT     d_values_in;        ///< Input values
    AggregatesOutputIteratorT       d_aggregates_out;   ///< Output value aggregates
    NumRunsOutputIteratorT          d_num_runs_out;     ///< Output pointer for total number of segments identified
    EqualityOpT                     equality_op;        ///< KeyT equality operator
    ReductionOpT                    reduction_op;       ///< Reduction operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-by-segment scan operator


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentReduceByKey(
        TempStorage&                temp_storage,       ///< Reference to temp_storage
        KeysInputIteratorT          d_keys_in,          ///< Input keys
        UniqueOutputIteratorT       d_unique_out,       ///< Unique output keys
        ValuesInputIteratorT        d_values_in,        ///< Input values
        AggregatesOutputIteratorT   d_aggregates_out,   ///< Output value aggregates
        NumRunsOutputIteratorT      d_num_runs_out,     ///< Output pointer for total number of segments identified
        EqualityOpT                 equality_op,        ///< KeyT equality operator
        ReductionOpT                reduction_op)       ///< ValueT reduction operator
    :
        temp_storage(temp_storage.Alias()),
        d_keys_in(d_keys_in),
        d_unique_out(d_unique_out),
        d_values_in(d_values_in),
        d_aggregates_out(d_aggregates_out),
        d_num_runs_out(d_num_runs_out),
        equality_op(equality_op),
        reduction_op(reduction_op),
        scan_op(reduction_op)
    {}


    //---------------------------------------------------------------------
    // Scatter utility methods
    //---------------------------------------------------------------------

    /**
     * Directly scatter flagged items to output offsets
     */
    __device__ __forceinline__ void ScatterDirect(
        KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
        OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
        OffsetT         (&segment_indices)[ITEMS_PER_THREAD])
    {
        // Scatter flagged keys and values
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (segment_flags[ITEM])
            {
                d_unique_out[segment_indices[ITEM]]     = scatter_items[ITEM].key;
                d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value;
            }
        }
    }


    /**
     * 2-phase scatter flagged items to output offsets
     *
     * The exclusive scan causes each head flag to be paired with the previous
     * value aggregate: the scatter offsets must be decremented for value aggregates
     */
    __device__ __forceinline__ void ScatterTwoPhase(
        KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
        OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
        OffsetT         (&segment_indices)[ITEMS_PER_THREAD],
        OffsetT         num_tile_segments,
        OffsetT         num_tile_segments_prefix)
    {
        CTA_SYNC();

        // Compact and scatter pairs
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (segment_flags[ITEM])
            {
                temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM];
            }
        }

        CTA_SYNC();

        for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS)
        {
            KeyValuePairT pair                                  = temp_storage.raw_exchange.Alias()[item];
            d_unique_out[num_tile_segments_prefix + item]       = pair.key;
            d_aggregates_out[num_tile_segments_prefix + item]   = pair.value;
        }
    }


    /**
     * Scatter flagged items
     */
    __device__ __forceinline__ void Scatter(
        KeyValuePairT   (&scatter_items)[ITEMS_PER_THREAD],
        OffsetT         (&segment_flags)[ITEMS_PER_THREAD],
        OffsetT         (&segment_indices)[ITEMS_PER_THREAD],
        OffsetT         num_tile_segments,
        OffsetT         num_tile_segments_prefix)
    {
        // Use two-phase scatter when it is enabled and the tile emits more
        // segments than there are threads (i.e., more than one selected item
        // per thread on average); otherwise fall back to direct scatter
        if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS))
        {
            ScatterTwoPhase(
                scatter_items,
                segment_flags,
                segment_indices,
                num_tile_segments,
                num_tile_segments_prefix);
        }
        else
        {
            ScatterDirect(
                scatter_items,
                segment_flags,
                segment_indices);
        }
    }


    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process a tile of input (dynamic chained scan)
     */
    template <bool IS_LAST_TILE>                ///< Whether the current tile is the last tile
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        KeyOutputT          keys[ITEMS_PER_THREAD];             // Tile keys
        KeyOutputT          prev_keys[ITEMS_PER_THREAD];        // Tile keys shuffled up
        ValueOutputT        values[ITEMS_PER_THREAD];           // Tile values
        OffsetT             head_flags[ITEMS_PER_THREAD];       // Segment head flags
        OffsetT             segment_indices[ITEMS_PER_THREAD];  // Segment indices
        OffsetValuePairT    scan_items[ITEMS_PER_THREAD];       // Zipped values and segment flags|indices
        KeyValuePairT       scatter_items[ITEMS_PER_THREAD];    // Zipped key-value pairs for scattering

        // Load keys
        if (IS_LAST_TILE)
            BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining);
        else
            BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys);

        // Load tile predecessor key in first thread
        KeyOutputT tile_predecessor;
        if (threadIdx.x == 0)
        {
            tile_predecessor = (tile_idx == 0) ?
                keys[0] :                       // First tile gets repeat of first item (thus first item will not be flagged as a head)
                d_keys_in[tile_offset - 1];     // Subsequent tiles get last key from previous tile
        }

        CTA_SYNC();

        // Load values
        if (IS_LAST_TILE)
            BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining);
        else
            BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values);

        CTA_SYNC();

        // Initialize head-flags and shuffle up the previous keys
        if (IS_LAST_TILE)
        {
            // Use custom flag operator to additionally flag the first out-of-bounds item
            GuardedInequalityWrapper<EqualityOpT> flag_op(equality_op, num_remaining);
            BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads(
                head_flags, keys, prev_keys, flag_op, tile_predecessor);
        }
        else
        {
            InequalityWrapper<EqualityOpT> flag_op(equality_op);
            BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads(
                head_flags, keys, prev_keys, flag_op, tile_predecessor);
        }

        // Zip values and head flags
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            scan_items[ITEM].value  = values[ITEM];
            scan_items[ITEM].key    = head_flags[ITEM];
        }
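
        // A hand-worked example of the segmented scan that follows
        // (illustrative, assuming ReductionOpT is cub::Sum and this is the
        // first tile): for keys [a, a, b, b] with values [1, 2, 3, 4], the
        // head flags are [0, 0, 1, 0] (the first item of the first tile is
        // never flagged). Zipping gives scan_items
        // [(0,1), (0,2), (1,3), (0,4)] as (key=flag, value) pairs. The
        // exclusive ReduceBySegmentOp scan yields (0,3) at the head of
        // segment 'b': its key (0) is the index of the completed segment 'a'
        // and its value (3 = 1+2) is that segment's aggregate, which is
        // exactly what Scatter() writes out.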

        // Perform exclusive tile scan
        OffsetValuePairT    block_aggregate;        // Inclusive block-wide scan aggregate
        OffsetT             num_segments_prefix;    // Number of segments prior to this tile
        OffsetValuePairT    total_aggregate;        // The tile prefix folded with block_aggregate
        if (tile_idx == 0)
        {
            // Scan first tile
            BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate);
            num_segments_prefix     = 0;
            total_aggregate         = block_aggregate;

            // Update tile status if there are successor tiles
            if ((!IS_LAST_TILE) && (threadIdx.x == 0))
                tile_state.SetInclusive(0, block_aggregate);
        }
        else
        {
            // Scan non-first tile
            TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
            BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op);

            block_aggregate         = prefix_op.GetBlockAggregate();
            num_segments_prefix     = prefix_op.GetExclusivePrefix().key;
            total_aggregate         = prefix_op.GetInclusivePrefix();
        }

        // Rezip scatter items and segment indices
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            scatter_items[ITEM].key     = prev_keys[ITEM];
            scatter_items[ITEM].value   = scan_items[ITEM].value;
            segment_indices[ITEM]       = scan_items[ITEM].key;
        }

        // At this point, each flagged segment head has:
        //  - The key for the previous segment
        //  - The reduced value from the previous segment
        //  - The segment index for the reduced value

        // Scatter flagged keys and values
        OffsetT num_tile_segments = block_aggregate.key;
        Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix);

        // Last thread in last tile will output final count (and last pair, if necessary)
        if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1))
        {
            OffsetT num_segments = num_segments_prefix + num_tile_segments;

            // If the last tile is a whole tile, there is no out-of-bounds item
            // to terminate the final segment, so output its key and aggregate here
            if (num_remaining == TILE_ITEMS)
            {
                d_unique_out[num_segments]      = keys[ITEMS_PER_THREAD - 1];
                d_aggregates_out[num_segments]  = total_aggregate.value;
                num_segments++;
            }

            // Output the total number of items selected
            *d_num_runs_out = num_segments;
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    __device__ __forceinline__ void ConsumeRange(
        int                 num_items,          ///< Total number of input items
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        int                 start_tile)         ///< The starting tile for the current grid
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = start_tile + blockIdx.x;          // Current tile index
        OffsetT tile_offset     = OffsetT(TILE_ITEMS) * tile_idx;   // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;          // Remaining items (including this tile)

        if (num_remaining > TILE_ITEMS)
        {
            // Not last tile
            ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state);
        }
        else if (num_remaining > 0)
        {
            // Last tile
            ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);
        }
    }

};
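

/**
 * A minimal usage sketch (an illustrative assumption; the kernel name
 * ExampleReduceByKeyKernel is hypothetical and dispatch details are elided).
 * A driver kernel instantiates the agent over shared storage and has each
 * thread block consume one tile:
 *
 * \code
 * template <
 *     typename AgentReduceByKeyPolicyT,
 *     typename KeysInputIteratorT,
 *     typename UniqueOutputIteratorT,
 *     typename ValuesInputIteratorT,
 *     typename AggregatesOutputIteratorT,
 *     typename NumRunsOutputIteratorT,
 *     typename ScanTileStateT,
 *     typename EqualityOpT,
 *     typename ReductionOpT,
 *     typename OffsetT>
 * __launch_bounds__ (int(AgentReduceByKeyPolicyT::BLOCK_THREADS))
 * __global__ void ExampleReduceByKeyKernel(
 *     KeysInputIteratorT          d_keys_in,
 *     UniqueOutputIteratorT       d_unique_out,
 *     ValuesInputIteratorT        d_values_in,
 *     AggregatesOutputIteratorT   d_aggregates_out,
 *     NumRunsOutputIteratorT      d_num_runs_out,
 *     ScanTileStateT              tile_state,
 *     int                         start_tile,
 *     EqualityOpT                 equality_op,
 *     ReductionOpT                reduction_op,
 *     OffsetT                     num_items)
 * {
 *     // Specialize the agent for the tuning policy and problem types
 *     typedef AgentReduceByKey<
 *         AgentReduceByKeyPolicyT, KeysInputIteratorT, UniqueOutputIteratorT,
 *         ValuesInputIteratorT, AggregatesOutputIteratorT, NumRunsOutputIteratorT,
 *         EqualityOpT, ReductionOpT, OffsetT> AgentReduceByKeyT;
 *
 *     // Shared memory for the agent
 *     __shared__ typename AgentReduceByKeyT::TempStorage temp_storage;
 *
 *     // Process this block's tile of input
 *     AgentReduceByKeyT(
 *         temp_storage, d_keys_in, d_unique_out, d_values_in,
 *         d_aggregates_out, d_num_runs_out, equality_op, reduction_op)
 *             .ConsumeRange(num_items, tile_state, start_tile);
 * }
 * \endcode
 */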


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
