/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_discontinuity.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentSegmentFixup
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentSegmentFixupPolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,            ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};
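
// Illustrative only: a minimal sketch of how a tuning policy of this type
// might be instantiated.  The specific values below are assumptions chosen
// for the sake of example, not the tunings selected by any dispatch policy.
//
//     typedef AgentSegmentFixupPolicy<
//             128,                            // 128 threads per thread block
//             4,                              // 4 key-value pairs per thread
//             BLOCK_LOAD_DIRECT,              // Direct (blocked) loads
//             LOAD_LDG,                       // Read-only cache loads
//             BLOCK_SCAN_WARP_SCANS>          // Warp-scan-based block scan
//         ExampleSegmentFixupPolicy;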


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key
 */
template <
    typename    AgentSegmentFixupPolicyT,       ///< Parameterized AgentSegmentFixupPolicy tuning policy type
    typename    PairsInputIteratorT,            ///< Random-access input iterator type for key-value pairs
    typename    AggregatesOutputIteratorT,      ///< Random-access output iterator type for values
    typename    EqualityOpT,                    ///< KeyT equality operator type
    typename    ReductionOpT,                   ///< ValueT reduction operator type
    typename    OffsetT>                        ///< Signed integer type for global offsets
struct AgentSegmentFixup
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // Data type of key-value input iterator
    typedef typename std::iterator_traits<PairsInputIteratorT>::value_type KeyValuePairT;

    // Value type
    typedef typename KeyValuePairT::Value ValueT;

    // Tile status descriptor interface type
    typedef ReduceByKeyScanTileState<ValueT, OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        BLOCK_THREADS       = AgentSegmentFixupPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = AgentSegmentFixupPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS          = BLOCK_THREADS * ITEMS_PER_THREAD,

        // Whether or not to perform fixup using RLE + global atomics
        USE_ATOMIC_FIXUP    = (CUB_PTX_ARCH >= 350) &&
                                (Equals<ValueT, float>::VALUE ||
                                 Equals<ValueT, int>::VALUE ||
                                 Equals<ValueT, unsigned int>::VALUE ||
                                 Equals<ValueT, unsigned long long>::VALUE),

        // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type)
        HAS_IDENTITY_ZERO   = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueT>::PRIMITIVE),
    };

    // Cache-modified input iterator wrapper type (for applying cache modifier) for key-value pairs
    typedef typename If<IsPointer<PairsInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,    // Wrap the native input pointer with CacheModifiedInputIterator
            PairsInputIteratorT>::Type                                                                      // Directly use the supplied input iterator type
        WrappedPairsInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for fixup values
    typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,    // Wrap the native input pointer with CacheModifiedInputIterator
            AggregatesOutputIteratorT>::Type                                                        // Directly use the supplied input iterator type
        WrappedFixupInputIteratorT;

    // Reduce-value-by-segment scan operator
    typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT;

    // Parameterized BlockLoad type for pairs
    typedef BlockLoad<
            KeyValuePairT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSegmentFixupPolicyT::LOAD_ALGORITHM>
        BlockLoadPairs;

    // Parameterized BlockScan type
    typedef BlockScan<
            KeyValuePairT,
            BLOCK_THREADS,
            AgentSegmentFixupPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            KeyValuePairT,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Shared memory type for this thread block
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
        };

        // Smem needed for loading key-value pairs
        typename BlockLoadPairs::TempStorage load_pairs;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedPairsInputIteratorT      d_pairs_in;         ///< Input key-value pairs
    AggregatesOutputIteratorT       d_aggregates_out;   ///< Output value aggregates
    WrappedFixupInputIteratorT      d_fixup_in;         ///< Fixup input values
    InequalityWrapper<EqualityOpT>  inequality_op;      ///< KeyT inequality operator
    ReductionOpT                    reduction_op;       ///< Reduction operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-by-segment scan operator


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentSegmentFixup(
        TempStorage&                temp_storage,       ///< Reference to temp_storage
        PairsInputIteratorT         d_pairs_in,         ///< Input key-value pairs
        AggregatesOutputIteratorT   d_aggregates_out,   ///< Output value aggregates
        EqualityOpT                 equality_op,        ///< KeyT equality operator
        ReductionOpT                reduction_op)       ///< ValueT reduction operator
    :
        temp_storage(temp_storage.Alias()),
        d_pairs_in(d_pairs_in),
        d_aggregates_out(d_aggregates_out),
        d_fixup_in(d_aggregates_out),
        inequality_op(equality_op),
        reduction_op(reduction_op),
        scan_op(reduction_op)
    {}


    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------


    /**
     * Process input tile.  Specialized for atomic-fixup
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        Int2Type<true>      use_atomic_fixup)   ///< Marker whether to use atomicAdd (instead of reduce-by-key)
    {
        KeyValuePairT   pairs[ITEMS_PER_THREAD];

        // Load pairs
        KeyValuePairT oob_pair;
        oob_pair.key = -1;

        if (IS_LAST_TILE)
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
        else
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);

        // RLE: accumulate runs of equal keys in registers, flushing each completed run with an atomicAdd
        #pragma unroll
        for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            ValueT* d_scatter = d_aggregates_out + pairs[ITEM - 1].key;
            if (pairs[ITEM].key != pairs[ITEM - 1].key)
                atomicAdd(d_scatter, pairs[ITEM - 1].value);
            else
                pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value);
        }

        // Flush last item if valid
        ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key;
        if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0))
            atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value);
    }


    /**
     * Process input tile.  Specialized for reduce-by-key fixup
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        Int2Type<false>     use_atomic_fixup)   ///< Marker whether to use atomicAdd (instead of reduce-by-key)
    {
        KeyValuePairT   pairs[ITEMS_PER_THREAD];
        KeyValuePairT   scatter_pairs[ITEMS_PER_THREAD];

        // Load pairs
        KeyValuePairT oob_pair;
        oob_pair.key = -1;

        if (IS_LAST_TILE)
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
        else
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);

        CTA_SYNC();

        KeyValuePairT tile_aggregate;
        if (tile_idx == 0)
        {
            // Exclusive scan of values and segment_flags
            BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, tile_aggregate);

            // Update tile status if this is not the last tile
            if (threadIdx.x == 0)
            {
                // Set first segment id to not trigger a flush (invalid from exclusive scan)
                scatter_pairs[0].key = pairs[0].key;

                if (!IS_LAST_TILE)
                    tile_state.SetInclusive(0, tile_aggregate);

            }
        }
        else
        {
            // Exclusive scan of values and segment_flags
            TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
            BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, prefix_op);
            tile_aggregate = prefix_op.GetBlockAggregate();
        }

        // Scatter updated values
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (scatter_pairs[ITEM].key != pairs[ITEM].key)
            {
                // Update the value at the key location
                ValueT value    = d_fixup_in[scatter_pairs[ITEM].key];
                value           = reduction_op(value, scatter_pairs[ITEM].value);

                d_aggregates_out[scatter_pairs[ITEM].key] = value;
            }
        }

        // Finalize the last item
        if (IS_LAST_TILE)
        {
            // Last thread will output final count and last item, if necessary
            if (threadIdx.x == BLOCK_THREADS - 1)
            {
                // If the last tile is a whole tile, the inclusive prefix contains accumulated value reduction for the last segment
                if (num_remaining == TILE_ITEMS)
                {
                    // Update the value at the key location
                    OffsetT last_key = pairs[ITEMS_PER_THREAD - 1].key;
                    d_aggregates_out[last_key] = reduction_op(tile_aggregate.value, d_fixup_in[last_key]);
                }
            }
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    __device__ __forceinline__ void ConsumeRange(
        int                 num_items,          ///< Total number of input items
        int                 num_tiles,          ///< Total number of input tiles
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index
        OffsetT tile_offset     = tile_idx * TILE_ITEMS;                    // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                  // Remaining items (including this tile)

        if (num_remaining > TILE_ITEMS)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
    }

};
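
// Illustrative only: a rough sketch of how a fixup kernel might construct this
// agent and consume its range of tiles.  The kernel shown here is hypothetical
// (its name and parameter list are assumptions, not part of the library); the
// actual kernel and its tunings are defined by the device-level dispatch layer.
//
//     template <
//         typename AgentSegmentFixupPolicyT,
//         typename PairsInputIteratorT,
//         typename AggregatesOutputIteratorT,
//         typename OffsetT,
//         typename ScanTileStateT>
//     __global__ void ExampleSegmentFixupKernel(
//         PairsInputIteratorT         d_pairs_in,         // Partial key-value aggregates to apply
//         AggregatesOutputIteratorT   d_aggregates_out,   // Output value aggregates to fix up
//         int                         num_items,          // Total number of fixup pairs
//         int                         num_tiles,          // Total number of input tiles
//         ScanTileStateT              tile_state)         // Tile status interface
//     {
//         // Specialize the agent for sum-reduction with default key equality
//         typedef AgentSegmentFixup<
//                 AgentSegmentFixupPolicyT,
//                 PairsInputIteratorT,
//                 AggregatesOutputIteratorT,
//                 cub::Equality,
//                 cub::Sum,
//                 OffsetT>
//             AgentSegmentFixupT;
//
//         // Shared memory for the agent
//         __shared__ typename AgentSegmentFixupT::TempStorage temp_storage;
//
//         // Each thread block consumes one tile of input
//         AgentSegmentFixupT(temp_storage, d_pairs_in, d_aggregates_out, cub::Equality(), cub::Sum())
//             .ConsumeRange(num_items, num_tiles, tile_state);
//     }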


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
