/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentScan
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    BlockStoreAlgorithm         _STORE_ALGORITHM,               ///< The BlockStore algorithm to use
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentScanPolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,            ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;          ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;           ///< Cache load modifier for reading input elements
    static const BlockStoreAlgorithm    STORE_ALGORITHM         = _STORE_ALGORITHM;         ///< The BlockStore algorithm to use
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;          ///< The BlockScan algorithm to use
};
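
// Example (illustrative only): a concrete tuning policy of this form might be
// instantiated with hypothetical values such as
//
//     typedef AgentScanPolicy<
//         128,                            // 128 threads per block
//         15,                             // 15 items per thread
//         BLOCK_LOAD_WARP_TRANSPOSE,      // BlockLoad algorithm
//         LOAD_DEFAULT,                   // cache load modifier for input reads
//         BLOCK_STORE_WARP_TRANSPOSE,     // BlockStore algorithm
//         BLOCK_SCAN_WARP_SCANS>          // BlockScan algorithm
//     ExampleScanPolicy;
//
// The policies actually used by CUB's device-wide scan are tuned per GPU
// architecture in the corresponding dispatch layer.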




/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan.
 */
template <
    typename AgentScanPolicyT,      ///< Parameterized AgentScanPolicyT tuning policy type
    typename InputIteratorT,        ///< Random-access input iterator type
    typename OutputIteratorT,       ///< Random-access output iterator type
    typename ScanOpT,               ///< Scan functor type
    typename InitValueT,            ///< The init_value element type for ScanOpT (cub::NullType for inclusive scan)
    typename OffsetT>               ///< Signed integer type for global offsets
struct AgentScan
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;

    // The output value type
    typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE),  // OutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<InputIteratorT>::value_type,                                          // ... then the input iterator's value type,
        typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT;                          // ... else the output iterator's value type

    // Tile status descriptor interface type
    typedef ScanTileState<OutputT> ScanTileStateT;

    // Input iterator wrapper type (for applying cache modifier)
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                           // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Constants
    enum
    {
        IS_INCLUSIVE        = Equals<InitValueT, NullType>::VALUE,            // Inclusive scan if no init_value type is provided
        BLOCK_THREADS       = AgentScanPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = AgentScanPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS          = BLOCK_THREADS * ITEMS_PER_THREAD,
    };
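
    // For example (illustrative numbers only): a policy of 128 threads x 15 items
    // per thread yields TILE_ITEMS = 1920 input items processed per thread block tile.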

    // Parameterized BlockLoad type
    typedef BlockLoad<
            OutputT,
            AgentScanPolicyT::BLOCK_THREADS,
            AgentScanPolicyT::ITEMS_PER_THREAD,
            AgentScanPolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockStore type
    typedef BlockStore<
            OutputT,
            AgentScanPolicyT::BLOCK_THREADS,
            AgentScanPolicyT::ITEMS_PER_THREAD,
            AgentScanPolicyT::STORE_ALGORITHM>
        BlockStoreT;

    // Parameterized BlockScan type
    typedef BlockScan<
            OutputT,
            AgentScanPolicyT::BLOCK_THREADS,
            AgentScanPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            OutputT,
            ScanOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Stateful BlockScan prefix callback type for managing a running total while scanning consecutive tiles
    typedef BlockScanRunningPrefixOp<
            OutputT,
            ScanOpT>
        RunningPrefixCallbackOp;

    // Shared memory type for this thread block
    union _TempStorage
    {
        typename BlockLoadT::TempStorage    load;       // Smem needed for tile loading
        typename BlockStoreT::TempStorage   store;      // Smem needed for tile storing

        struct
        {
            typename TilePrefixCallbackOpT::TempStorage  prefix;     // Smem needed for cooperative prefix callback
            typename BlockScanT::TempStorage             scan;       // Smem needed for tile scanning
        };
    };
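
    // Note: the load, store, and (prefix + scan) storage can share the same shared-memory
    // footprint because each tile is processed in distinct phases (load, scan, store) that
    // are separated by CTA_SYNC() barriers in ConsumeTile, so no two members are live at once.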

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&               temp_storage;       ///< Reference to temp_storage
    WrappedInputIteratorT       d_in;               ///< Input data
    OutputIteratorT             d_out;              ///< Output data
    ScanOpT                     scan_op;            ///< Binary scan operator
    InitValueT                  init_value;         ///< The init_value element for ScanOpT


    //---------------------------------------------------------------------
    // Block scan utility methods
    //---------------------------------------------------------------------

    /**
     * Exclusive scan specialization (first tile)
     */
    __device__ __forceinline__
    void ScanTile(
        OutputT             (&items)[ITEMS_PER_THREAD],
        OutputT             init_value,
        ScanOpT             scan_op,
        OutputT             &block_aggregate,
        Int2Type<false>     /*is_inclusive*/)
    {
        BlockScanT(temp_storage.scan).ExclusiveScan(items, items, init_value, scan_op, block_aggregate);
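        // BlockScan's ExclusiveScan reports a block_aggregate that does not include
        // init_value, so fold init_value in here: the value published to (or carried by)
        // the prefix machinery must be the inclusive total of everything this tile covers.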
        block_aggregate = scan_op(init_value, block_aggregate);
    }


    /**
     * Inclusive scan specialization (first tile)
     */
    __device__ __forceinline__
    void ScanTile(
        OutputT             (&items)[ITEMS_PER_THREAD],
        InitValueT          /*init_value*/,
        ScanOpT             scan_op,
        OutputT             &block_aggregate,
        Int2Type<true>      /*is_inclusive*/)
    {
        BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate);
    }


    /**
     * Exclusive scan specialization (subsequent tiles)
     */
    template <typename PrefixCallback>
    __device__ __forceinline__
    void ScanTile(
        OutputT             (&items)[ITEMS_PER_THREAD],
        ScanOpT             scan_op,
        PrefixCallback      &prefix_op,
        Int2Type<false>     /*is_inclusive*/)
    {
        BlockScanT(temp_storage.scan).ExclusiveScan(items, items, scan_op, prefix_op);
    }


    /**
     * Inclusive scan specialization (subsequent tiles)
     */
    template <typename PrefixCallback>
    __device__ __forceinline__
    void ScanTile(
        OutputT             (&items)[ITEMS_PER_THREAD],
        ScanOpT             scan_op,
        PrefixCallback      &prefix_op,
        Int2Type<true>      /*is_inclusive*/)
    {
        BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, prefix_op);
    }


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentScan(
        TempStorage&    temp_storage,       ///< Reference to temp_storage
        InputIteratorT  d_in,               ///< Input data
        OutputIteratorT d_out,              ///< Output data
        ScanOpT         scan_op,            ///< Binary scan operator
        InitValueT      init_value)         ///< Initial value to seed the exclusive scan
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_out(d_out),
        scan_op(scan_op),
        init_value(init_value)
    {}


    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process a tile of input (dynamic chained scan)
     */
    template <bool IS_LAST_TILE>                ///< Whether the current tile is the last tile
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        // Load items
        OutputT items[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining);
        else
            BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);

        CTA_SYNC();

        // Perform tile scan
        if (tile_idx == 0)
        {
            // Scan first tile
            OutputT block_aggregate;
            ScanTile(items, init_value, scan_op, block_aggregate, Int2Type<IS_INCLUSIVE>());
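            // Thread0 publishes the first tile's inclusive aggregate to the global tile
            // status descriptor so that subsequent tiles' look-back can terminate against it.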
            if ((!IS_LAST_TILE) && (threadIdx.x == 0))
                tile_state.SetInclusive(0, block_aggregate);
        }
        else
        {
            // Scan non-first tile
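            // The prefix callback implements decoupled look-back: it inspects (and waits on)
            // the status of preceding tiles in tile_state, computes this tile's exclusive
            // prefix, and publishes this tile's own partial/inclusive status for later tiles.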
            TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
            ScanTile(items, scan_op, prefix_op, Int2Type<IS_INCLUSIVE>());
        }

        CTA_SYNC();

        // Store items
        if (IS_LAST_TILE)
            BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, num_remaining);
        else
            BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items);
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    __device__ __forceinline__ void ConsumeRange(
        int                 num_items,          ///< Total number of input items
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        int                 start_tile)         ///< The starting tile for the current grid
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = start_tile + blockIdx.x;          // Current tile index
        OffsetT tile_offset     = OffsetT(TILE_ITEMS) * tile_idx;   // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;          // Remaining items (including this tile)

        if (num_remaining > TILE_ITEMS)
        {
            // Not last tile
            ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state);
        }
        else if (num_remaining > 0)
        {
            // Last tile
            ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);
        }
    }
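
    // Illustrative usage (a sketch only, not the actual CUB kernel): a scan kernel
    // might allocate the agent's shared storage, construct an AgentScan, and hand it
    // one tile per thread block via ConsumeRange. Names such as ExampleScanKernel and
    // ScanPolicyT below are hypothetical.
    //
    //     template <typename ScanPolicyT, typename InputIteratorT, typename OutputIteratorT,
    //               typename ScanTileStateT, typename ScanOpT, typename InitValueT, typename OffsetT>
    //     __global__ void ExampleScanKernel(
    //         InputIteratorT  d_in,
    //         OutputIteratorT d_out,
    //         ScanTileStateT  tile_state,
    //         int             start_tile,
    //         OffsetT         num_items,
    //         ScanOpT         scan_op,
    //         InitValueT      init_value)
    //     {
    //         // Thread block abstraction type for scanning input tiles
    //         typedef AgentScan<ScanPolicyT, InputIteratorT, OutputIteratorT,
    //                           ScanOpT, InitValueT, OffsetT> AgentScanT;
    //
    //         // Shared memory for AgentScan
    //         __shared__ typename AgentScanT::TempStorage temp_storage;
    //
    //         // Process this block's tile of the current pass
    //         AgentScanT(temp_storage, d_in, d_out, scan_op, init_value).ConsumeRange(
    //             num_items, tile_state, start_tile);
    //     }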


    //---------------------------------------------------------------------
    // Scan a sequence of consecutive tiles (independent of other thread blocks)
    //---------------------------------------------------------------------

    /**
     * Process a tile of input
     */
    template <
        bool                        IS_FIRST_TILE,
        bool                        IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT                     tile_offset,                ///< Tile offset
        RunningPrefixCallbackOp&    prefix_op,                  ///< Running prefix operator
        int                         valid_items = TILE_ITEMS)   ///< Number of valid items in the tile
    {
        // Load items
        OutputT items[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, valid_items);
        else
            BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);

        CTA_SYNC();

        // Block scan
        if (IS_FIRST_TILE)
        {
            OutputT block_aggregate;
            ScanTile(items, init_value, scan_op, block_aggregate, Int2Type<IS_INCLUSIVE>());
            prefix_op.running_total = block_aggregate;
        }
        else
        {
            ScanTile(items, scan_op, prefix_op, Int2Type<IS_INCLUSIVE>());
        }

        CTA_SYNC();

        // Store items
        if (IS_LAST_TILE)
            BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, valid_items);
        else
            BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items);
    }


    /**
     * Scan a consecutive share of input tiles
     */
    __device__ __forceinline__ void ConsumeRange(
        OffsetT  range_offset,      ///< [in] Threadblock begin offset (inclusive)
        OffsetT  range_end)         ///< [in] Threadblock end offset (exclusive)
    {
        BlockScanRunningPrefixOp<OutputT, ScanOpT> prefix_op(scan_op);

        if (range_offset + TILE_ITEMS <= range_end)
        {
            // Consume first tile of input (full, so no guarded I/O needed)
            ConsumeTile<true, false>(range_offset, prefix_op);
            range_offset += TILE_ITEMS;

            // Consume subsequent full tiles of input
            while (range_offset + TILE_ITEMS <= range_end)
            {
                ConsumeTile<false, false>(range_offset, prefix_op);
                range_offset += TILE_ITEMS;
            }

            // Consume a partially-full last tile (guarded I/O on valid_items)
            if (range_offset < range_end)
            {
                int valid_items = range_end - range_offset;
                ConsumeTile<false, true>(range_offset, prefix_op, valid_items);
            }
        }
        else
        {
            // Consume the first (and only) tile of input, which is partially-full
            int valid_items = range_end - range_offset;
            ConsumeTile<true, true>(range_offset, prefix_op, valid_items);
        }
    }


    /**
     * Scan a consecutive share of input tiles, seeded with the specified prefix value
     */
    __device__ __forceinline__ void ConsumeRange(
        OffsetT range_offset,                       ///< [in] Threadblock begin offset (inclusive)
        OffsetT range_end,                          ///< [in] Threadblock end offset (exclusive)
        OutputT prefix)                             ///< [in] The prefix to apply to the scan segment
    {
        BlockScanRunningPrefixOp<OutputT, ScanOpT> prefix_op(prefix, scan_op);

        // Consume full tiles of input (the seed prefix already lives in prefix_op,
        // so these tiles are never treated as "first" tiles)
        while (range_offset + TILE_ITEMS <= range_end)
        {
            ConsumeTile<false, false>(range_offset, prefix_op);
            range_offset += TILE_ITEMS;
        }

        // Consume a partially-full last tile (guarded I/O on valid_items)
        if (range_offset < range_end)
        {
            int valid_items = range_end - range_offset;
            ConsumeTile<false, true>(range_offset, prefix_op, valid_items);
        }
    }
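
    // Illustrative usage of the independent (non-chained) interface, a sketch only:
    // a single thread block could scan a contiguous range [range_begin, range_end)
    // on its own, optionally seeding the scan with a prefix carried over from a
    // previous segment. AgentScanT and the variable names below are hypothetical.
    //
    //     __shared__ typename AgentScanT::TempStorage temp_storage;
    //     AgentScanT agent(temp_storage, d_in, d_out, scan_op, init_value);
    //
    //     agent.ConsumeRange(range_begin, range_end);                   // unseeded
    //     agent.ConsumeRange(range_begin, range_end, running_prefix);   // seeded with a prior total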

};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
