1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill.  All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the name of the NVIDIA CORPORATION nor the
13  *       names of its contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
29 /**
30  * \file
31  * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select.
32  */
33 
34 #pragma once
35 
36 #include <iterator>
37 
38 #include "single_pass_scan_operators.cuh"
39 #include "../block/block_load.cuh"
40 #include "../block/block_store.cuh"
41 #include "../block/block_scan.cuh"
42 #include "../block/block_exchange.cuh"
43 #include "../block/block_discontinuity.cuh"
44 #include "../grid/grid_queue.cuh"
45 #include "../iterator/cache_modified_input_iterator.cuh"
46 #include "../util_namespace.cuh"
47 
48 /// Optional outer namespace(s)
49 CUB_NS_PREFIX
50 
51 /// CUB namespace
52 namespace cub {
53 
54 
55 /******************************************************************************
56  * Tuning policy types
57  ******************************************************************************/
58 
/**
 * Parameterizable tuning policy type for AgentSelectIf.
 *
 * This type carries no state: it only republishes its template arguments as
 * compile-time constants so that AgentSelectIf can read them by name.
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentSelectIfPolicy
{
    /// The BlockLoad algorithm to use
    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;

    /// Cache load modifier for reading input elements
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;

    /// The BlockScan algorithm to use
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;

    /// Integral tile-shape constants
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD             ///< Items per thread (per tile of input)
    };
};
80 
81 
82 
83 
84 /******************************************************************************
85  * Thread block abstractions
86  ******************************************************************************/
87 
88 
/**
 * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection
 *
 * The selection method is fixed at compile time (see SELECT_METHOD below):
 *   - functor-based selection if the SelectOpT functor type != NullType,
 *   - otherwise flag-based selection if FlagsInputIteratorT's value type != NullType,
 *   - otherwise discontinuity selection (keep unique).
 */
template <
    typename    AgentSelectIfPolicyT,           ///< Parameterized AgentSelectIfPolicy tuning policy type
    typename    InputIteratorT,                 ///< Random-access input iterator type for selection items
    typename    FlagsInputIteratorT,            ///< Random-access input iterator type for selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
    typename    SelectedOutputIteratorT,        ///< Random-access output iterator type for selected items
    typename    SelectOpT,                      ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
    typename    EqualityOpT,                    ///< Equality operator type (NullType if a selection functor or selection flags is to be used for selection)
    typename    OffsetT,                        ///< Signed integer type for global offsets
    bool        KEEP_REJECTS>                   ///< Whether or not we push rejected items to the back of the output
struct AgentSelectIf
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;

    // The output value type
    typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE),  // OutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<InputIteratorT>::value_type,                                                  // ... then the input iterator's value type,
        typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT;                          // ... else the output iterator's value type

    // The flag value type
    typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT;

    // Tile status descriptor interface type
    typedef ScanTileState<OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        // Selection method identifiers (used as Int2Type tags for overload dispatch)
        USE_SELECT_OP,
        USE_SELECT_FLAGS,
        USE_DISCONTINUITY,

        BLOCK_THREADS           = AgentSelectIfPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentSelectIfPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
        TWO_PHASE_SCATTER       = (ITEMS_PER_THREAD > 1),

        // Functor takes precedence over flags; flags take precedence over discontinuity detection
        SELECT_METHOD           = (!Equals<SelectOpT, NullType>::VALUE) ?
                                    USE_SELECT_OP :
                                    (!Equals<FlagT, NullType>::VALUE) ?
                                        USE_SELECT_FLAGS :
                                        USE_DISCONTINUITY
    };

    // Cache-modified input iterator wrapper type (for applying cache modifier) for items
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                               // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for flags
    typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,    // Wrap the native input pointer with CacheModifiedInputIterator
            FlagsInputIteratorT>::Type                                                          // Directly use the supplied input iterator type
        WrappedFlagsInputIteratorT;

    // Parameterized BlockLoad type for input data (items are loaded directly as OutputT)
    typedef BlockLoad<
            OutputT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockLoad type for flags
    typedef BlockLoad<
            FlagT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadFlags;

    // Parameterized BlockDiscontinuity type for items
    typedef BlockDiscontinuity<
            OutputT,
            BLOCK_THREADS>
        BlockDiscontinuityT;

    // Parameterized BlockScan type
    typedef BlockScan<
            OffsetT,
            BLOCK_THREADS,
            AgentSelectIfPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            OffsetT,
            cub::Sum,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Item exchange type (one slot per item in the tile)
    typedef OutputT ItemExchangeT[TILE_ITEMS];

    // Shared memory type for this thread block.  The members are unioned, so the
    // methods below place a CTA_SYNC() between phases that reuse this storage.
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
            typename BlockDiscontinuityT::TempStorage       discontinuity;  // Smem needed for discontinuity detection
        };

        // Smem needed for loading items
        typename BlockLoadT::TempStorage load_items;

        // Smem needed for loading flags
        typename BlockLoadFlags::TempStorage load_flags;

        // Smem needed for compacting items (allows non POD items in this union)
        Uninitialized<ItemExchangeT> raw_exchange;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedInputIteratorT           d_in;               ///< Input items
    SelectedOutputIteratorT         d_selected_out;     ///< Output items (selected items, plus rejected items when KEEP_REJECTS)
    WrappedFlagsInputIteratorT      d_flags_in;         ///< Input selection flags (if applicable)
    InequalityWrapper<EqualityOpT>  inequality_op;      ///< Inequality operator (wraps the supplied equality operator)
    SelectOpT                       select_op;          ///< Selection operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Constructor.  The mem-initializers are listed in member declaration order,
     * which is the order in which C++ actually initializes them.
     */
    __device__ __forceinline__
    AgentSelectIf(
        TempStorage                 &temp_storage,      ///< Reference to temp_storage
        InputIteratorT              d_in,               ///< Input data
        FlagsInputIteratorT         d_flags_in,         ///< Input selection flags (if applicable)
        SelectedOutputIteratorT     d_selected_out,     ///< Output data
        SelectOpT                   select_op,          ///< Selection operator
        EqualityOpT                 equality_op,        ///< Equality operator
        OffsetT                     num_items)          ///< Total number of input items
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_selected_out(d_selected_out),
        d_flags_in(d_flags_in),
        inequality_op(equality_op),
        select_op(select_op),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Initialize selections (specialized for selection operator).
     *
     * Each in-bounds item is flagged with the result of select_op.  In the last
     * tile, out-of-bounds items are flagged as selected (1); the tile consumers
     * subtract them from the selection counts afterwards.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     /*tile_offset*/,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_OP>     /*select_method*/)
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Out-of-bounds items are flagged as selected
            selection_flags[ITEM] = 1;

            // Only evaluate the functor on valid items
            if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items))
                selection_flags[ITEM] = select_op(items[ITEM]);
        }
    }


    /**
     * Initialize selections (specialized for valid flags).
     *
     * Loads the tile's flags from d_flags_in and converts them to OffsetT.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&/*items*/)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_FLAGS>  /*select_method*/)
    {
        // load_flags is unioned with load_items, which the caller has just used
        CTA_SYNC();

        FlagT flags[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
        {
            // Guarded load: out-of-bounds flags take the default value 1 (selected)
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1);
        }
        else
        {
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags);
        }

        // Convert flag type to selection_flags type
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            selection_flags[ITEM] = flags[ITEM];
        }
    }


    /**
     * Initialize selections (specialized for discontinuity detection).
     *
     * Flags the head of each run of items using inequality_op.  For tiles other
     * than the first, the item immediately preceding the tile is supplied as the
     * predecessor of the tile's first item.
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_DISCONTINUITY> /*select_method*/)
    {
        if (IS_FIRST_TILE)
        {
            // discontinuity storage is unioned with load_items, which the caller has just used
            CTA_SYNC();

            // Set head selection flags (no predecessor is supplied for the first tile)
            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op);
        }
        else
        {
            // Only thread 0 reads the predecessor from global memory; the other
            // threads pass an uninitialized value.
            // NOTE(review): relies on FlagHeads consuming only thread 0's
            // predecessor argument -- confirm against BlockDiscontinuity's documentation.
            OutputT tile_predecessor;
            if (threadIdx.x == 0)
                tile_predecessor = d_in[tile_offset - 1];

            CTA_SYNC();

            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor);
        }

        // Flag out-of-bounds items as selected
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items))
                selection_flags[ITEM] = 1;
        }
    }


    //---------------------------------------------------------------------
    // Scatter utility methods
    //---------------------------------------------------------------------

    /**
     * Scatter flagged items to output offsets (specialized for direct scattering).
     *
     * Each thread writes its own flagged items straight to d_selected_out at
     * their scanned indices.
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterDirect(
        OutputT (&items)[ITEMS_PER_THREAD],
        OffsetT (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT (&selection_indices)[ITEMS_PER_THREAD],
        OffsetT num_selections)
    {
        // Scatter flagged items
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (selection_flags[ITEM])
            {
                // In the last tile, skip flagged items whose index is at or beyond the selection total
                if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections)
                {
                    d_selected_out[selection_indices[ITEM]] = items[ITEM];
                }
            }
        }
    }


    /**
     * Scatter flagged items to output offsets (specialized for two-phase
     * scattering, discarding rejected items).
     *
     * Phase 1 compacts the flagged items into shared memory; phase 2 copies the
     * compacted run to d_selected_out with a block-strided loop.
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             /*num_tile_items*/,                         ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         /*num_rejected_prefix*/,                    ///< Total number of rejections prior to this tile
        Int2Type<false> /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage used by the caller
        CTA_SYNC();

        // Compact flagged items into shared memory at their tile-local offsets
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
            if (selection_flags[ITEM])
            {
                temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
            }
        }

        CTA_SYNC();

        // Copy the compacted items to global memory
        for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS)
        {
            d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item];
        }
    }


    /**
     * Scatter items to output offsets (specialized for two-phase scattering,
     * keeping rejected items).
     *
     * Phase 1 partitions the tile in shared memory with rejections first,
     * followed by selections.  Phase 2 writes selections forward from
     * num_selections_prefix and rejections backward from the end of the output
     * (offset num_items - num_rejected_prefix - 1 downward).
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             num_tile_items,                             ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         num_rejected_prefix,                        ///< Total number of rejections prior to this tile
        Int2Type<true>  /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage used by the caller
        CTA_SYNC();

        int tile_num_rejections = num_tile_items - num_tile_selections;

        // Scatter items to shared memory (rejections first)
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx                = (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
            int local_selection_idx     = selection_indices[ITEM] - num_selections_prefix;
            int local_rejection_idx     = item_idx - local_selection_idx;
            int local_scatter_offset    = (selection_flags[ITEM]) ?
                                            tile_num_rejections + local_selection_idx :
                                            local_rejection_idx;

            temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
        }

        CTA_SYNC();

        // Gather items from shared memory (block-strided) and scatter to global
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx            = (ITEM * BLOCK_THREADS) + threadIdx.x;
            int rejection_idx       = item_idx;
            int selection_idx       = item_idx - tile_num_rejections;
            OffsetT scatter_offset  = (item_idx < tile_num_rejections) ?
                                        num_items - num_rejected_prefix - rejection_idx - 1 :
                                        num_selections_prefix + selection_idx;

            OutputT item = temp_storage.raw_exchange.Alias()[item_idx];

            // In the last tile, slots at or beyond num_tile_items hold out-of-bounds items
            if (!IS_LAST_TILE || (item_idx < num_tile_items))
            {
                d_selected_out[scatter_offset] = item;
            }
        }
    }


    /**
     * Scatter flagged items, choosing between two-phase and direct scattering.
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void Scatter(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             num_tile_items,                             ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         num_rejected_prefix,                        ///< Total number of rejections prior to this tile
        OffsetT         num_selections)                             ///< Total number of selections including this tile
    {
        // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selected items per thread is greater than one
        if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS)))
        {
            ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_tile_items,
                num_tile_selections,
                num_selections_prefix,
                num_rejected_prefix,
                Int2Type<KEEP_REJECTS>());
        }
        else
        {
            ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_selections);
        }
    }

    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------


    /**
     * Process first tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeFirstTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded load for a possibly partial last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<true, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags
        OffsetT num_tile_selections;
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections);

        if (threadIdx.x == 0)
        {
            // Update tile status if this is not the last tile
            if (!IS_LAST_TILE)
                tile_state.SetInclusive(0, num_tile_selections);
        }

        // Discount the out-of-bounds items, which were all flagged as selected
        if (IS_LAST_TILE)
            num_tile_selections -= (TILE_ITEMS - num_tile_items);

        // Scatter flagged items (no selections or rejections precede the first tile)
        Scatter<IS_LAST_TILE, true>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            0,
            0,
            num_tile_selections);

        return num_tile_selections;
    }


    /**
     * Process subsequent tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeSubsequentTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded load for a possibly partial last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<false, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags, seeded with the prefix obtained through the callback
        TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx);
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op);

        OffsetT num_tile_selections     = prefix_op.GetBlockAggregate();
        OffsetT num_selections          = prefix_op.GetInclusivePrefix();
        OffsetT num_selections_prefix   = prefix_op.GetExclusivePrefix();

        // Every preceding tile is full, so preceding rejections = preceding items - preceding selections.
        // Widen to OffsetT before multiplying so the product is not computed (and overflowed) in int.
        OffsetT num_rejected_prefix     = (static_cast<OffsetT>(tile_idx) * TILE_ITEMS) - num_selections_prefix;

        // Discount the out-of-bounds items, which were all flagged as selected
        if (IS_LAST_TILE)
        {
            int num_discount    = TILE_ITEMS - num_tile_items;
            num_selections      -= num_discount;
            num_tile_selections -= num_discount;
        }

        // Scatter flagged items
        Scatter<IS_LAST_TILE, false>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            num_selections_prefix,
            num_rejected_prefix,
            num_selections);

        return num_selections;
    }


    /**
     * Process a tile of input.  Dispatches to the first-tile or subsequent-tile
     * path and returns the running count of selections (including this tile).
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OffsetT num_selections;
        if (tile_idx == 0)
        {
            num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state);
        }
        else
        {
            num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state);
        }

        return num_selections;
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan.
     *
     * Each thread block consumes exactly one tile, identified from its block
     * indices.  The block that consumes the last tile also writes the total
     * number of selected items to d_num_selected_out (from thread 0).
     */
    template <typename NumSelectedIteratorT>        ///< Output iterator type for recording the number of items selected
    __device__ __forceinline__ void ConsumeRange(
        int                     num_tiles,          ///< Total number of input tiles
        ScanTileStateT&         tile_state,         ///< Global tile state descriptor
        NumSelectedIteratorT    d_num_selected_out) ///< Output total number of items selected
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;        // Current tile index

        // Global offset for the current tile.  Widen to OffsetT before multiplying
        // so the product is not computed (and overflowed) in int.
        OffsetT tile_offset     = static_cast<OffsetT>(tile_idx) * TILE_ITEMS;

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(TILE_ITEMS, tile_idx, tile_offset, tile_state);
        }
        else
        {
            // The last tile (possibly partially-full)
            OffsetT num_remaining   = num_items - tile_offset;
            OffsetT num_selections  = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);

            if (threadIdx.x == 0)
            {
                // Output the total number of items selected
                *d_num_selected_out = num_selections;
            }
        }
    }

};
698 
699 
700 
701 }               // CUB namespace
702 CUB_NS_POSTFIX  // Optional outer namespace(s)
703 
704