1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill.  All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the name of the NVIDIA CORPORATION nor the
13  *       names of its contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
29 /**
30  * \file
31  * cub::WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
32  */
33 
34 #pragma once
35 
36 #include "../../thread/thread_operators.cuh"
37 #include "../../thread/thread_load.cuh"
38 #include "../../thread/thread_store.cuh"
39 #include "../../util_type.cuh"
40 #include "../../util_namespace.cuh"
41 
42 /// Optional outer namespace(s)
43 CUB_NS_PREFIX
44 
45 /// CUB namespace
46 namespace cub {
47 
/**
 * \brief WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
 *
 * Shared-memory layout (per logical warp): WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS cells.
 * Cells [0, HALF_WARP_THREADS) are a padding region; lane i publishes its value in cell HALF_WARP_THREADS + i.
 * The padding lets a lane read "lane_id - OFFSET" (OFFSET <= HALF_WARP_THREADS) without going out of bounds.
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpScanSmem
{
    /******************************************************************************
     * Constants and type definitions
     ******************************************************************************/

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// Whether the logical warp size is a power-of-two
        IS_POW_OF_TWO = PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of warp scan steps
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of threads in half a warp
        HALF_WARP_THREADS = 1 << (STEPS - 1),

        /// The number of shared memory elements per warp
        WARP_SMEM_ELEMENTS =  LOGICAL_WARP_THREADS + HALF_WARP_THREADS,
    };

    /// Storage cell type (workaround for SM1x compiler bugs with custom-ops like Max() on signed chars)
    typedef typename If<((Equals<T, char>::VALUE || Equals<T, signed char>::VALUE) && (PTX_ARCH < 200)), int, T>::Type CellT;

    /// Shared memory storage layout type (1.5 warps-worth of elements for each warp)
    typedef CellT _TempStorage[WARP_SMEM_ELEMENTS];

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    _TempStorage    &temp_storage;  ///< Reference to this logical warp's scratch cells
    unsigned int    lane_id;        ///< Lane index within the logical warp
    unsigned int    member_mask;    ///< Lane mask passed to WARP_SYNC


    /******************************************************************************
     * Construction
     ******************************************************************************/

    /// Constructor
    __device__ __forceinline__ WarpScanSmem(
        TempStorage     &temp_storage)
    :
        temp_storage(temp_storage.Alias()),

        lane_id(IS_ARCH_WARP ?
            LaneId() :
            LaneId() % LOGICAL_WARP_THREADS),

        member_mask((0xffffffff >> (32 - LOGICAL_WARP_THREADS)) << ((IS_ARCH_WARP || !IS_POW_OF_TWO ) ?
            0 : // arch-width and non-power-of-two subwarps cannot be tiled with the arch-warp
            ((LaneId() / LOGICAL_WARP_THREADS) * LOGICAL_WARP_THREADS)))
    {}


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /// Basic inclusive scan iteration (template unrolled, inductive-case specialization)
    ///
    /// Each lane publishes its running partial, then combines it with the partial
    /// published by lane (lane_id - 2^STEP).  When HAS_IDENTITY is true the padding
    /// region is expected to hold the identity, so every lane may combine unconditionally.
    template <
        bool        HAS_IDENTITY,
        int         STEP,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &partial,
        ScanOp                  scan_op,
        Int2Type<STEP>          /*step*/)
    {
        const int OFFSET = 1 << STEP;

        // Share partial into buffer
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) partial);

        WARP_SYNC(member_mask);

        // Update partial if addend is in range
        if (HAS_IDENTITY || (lane_id >= OFFSET))
        {
            T addend = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - OFFSET]);
            partial = scan_op(addend, partial);
        }

        // Keep the next step's store from overwriting a cell a lagging lane has not yet read
        WARP_SYNC(member_mask);

        ScanStep<HAS_IDENTITY>(partial, scan_op, Int2Type<STEP + 1>());
    }


    /// Basic inclusive scan iteration (template unrolled, base-case specialization)
    template <
        bool        HAS_IDENTITY,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &/*partial*/,
        ScanOp                  /*scan_op*/,
        Int2Type<STEPS>         /*step*/)
    {}


    /// Inclusive prefix scan (specialized for summation across primitive types)
    __device__ __forceinline__ void InclusiveScan(
        T                       input,              ///< [in] Calling thread's input item.
        T                       &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        Sum                     scan_op,            ///< [in] Binary scan operator
        Int2Type<true>          /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        // Seed cells [0, LOGICAL_WARP_THREADS) with zero; this covers the whole padding
        // region, so the identity-aware ScanStep can read "lane_id - OFFSET" unconditionally.
        T identity = 0;
        ThreadStore<STORE_VOLATILE>(&temp_storage[lane_id], (CellT) identity);

        WARP_SYNC(member_mask);

        // Iterate scan steps
        output = input;
        ScanStep<true>(output, scan_op, Int2Type<0>());
    }


    /// Inclusive prefix scan
    template <typename ScanOp, int IS_PRIMITIVE>
    __device__ __forceinline__ void InclusiveScan(
        T                       input,              ///< [in] Calling thread's input item.
        T                       &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp                  scan_op,            ///< [in] Binary scan operator
        Int2Type<IS_PRIMITIVE>  /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        // Iterate scan steps (no identity available: out-of-range lanes skip the combine)
        output = input;
        ScanStep<false>(output, scan_op, Int2Type<0>());
    }


    /******************************************************************************
     * Interface
     ******************************************************************************/

    //---------------------------------------------------------------------
    // Broadcast
    //---------------------------------------------------------------------

    /// Broadcast: returns \p input from lane \p src_lane to every lane of the logical warp
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        unsigned int    src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        if (lane_id == src_lane)
        {
            ThreadStore<STORE_VOLATILE>(temp_storage, (CellT) input);
        }

        WARP_SYNC(member_mask);

        return (T)ThreadLoad<LOAD_VOLATILE>(temp_storage);
    }


    //---------------------------------------------------------------------
    // Inclusive operations
    //---------------------------------------------------------------------

    /// Inclusive scan
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        InclusiveScan(input, inclusive_output, scan_op, Int2Type<Traits<T>::PRIMITIVE>());
    }


    /// Inclusive scan with aggregate
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        InclusiveScan(input, inclusive_output, scan_op);

        // Retrieve aggregate: the last lane's inclusive value lands in the last cell
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        WARP_SYNC(member_mask);
    }


    //---------------------------------------------------------------------
    // Get exclusive from inclusive
    //---------------------------------------------------------------------

    /// Update inclusive and exclusive using input and inclusive
    ///
    /// Lane 0 reads the padding cell just below its own slot; its exclusive value is
    /// whatever that cell holds (e.g. zero after the Sum/primitive inclusive scan).
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T                       /*input*/,      ///< [in]
        T                       &inclusive,     ///< [in, out]
        T                       &exclusive,     ///< [out]
        ScanOpT                 /*scan_op*/,    ///< [in]
        IsIntegerT              /*is_integer*/) ///< [in]
    {
        // initial value unknown
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }

    /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update(
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // initial value presumed 0
        exclusive = inclusive - input;
    }

    /// Update inclusive and exclusive using input, inclusive, and initial value
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        // Lane 0 has no predecessor: use the initial value rather than reading the
        // (possibly uninitialized) padding cell only to discard it.
        if (lane_id == 0)
            exclusive = initial_value;
        else
            exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }

    /// Update inclusive and exclusive using input, inclusive, and initial value (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                scan_op,
        T                       initial_value,
        Int2Type<true>          /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        exclusive = inclusive - input;
    }


    /// Update inclusive, exclusive, and warp aggregate using input and inclusive
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 /*scan_op*/,
        IsIntegerT              /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }

    /// Update inclusive, exclusive, and warp aggregate using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
        exclusive = inclusive - input;
    }

    /// Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        // Broadcast warp aggregate (computed from the un-seeded inclusive values)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        // The last lane's next store overwrites WARP_SMEM_ELEMENTS - 1; wait until every lane has read it
        WARP_SYNC(member_mask);

        // Update inclusive with initial value
        inclusive = scan_op(initial_value, inclusive);

        // Get exclusive from inclusive.  Each lane publishes into its own slot and reads its
        // predecessor's, so every index stays within [0, WARP_SMEM_ELEMENTS).  This includes
        // 2-thread logical warps, where an index of "lane_id - 2" would wrap around
        // (lane_id is unsigned) for lane 0.
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        if (lane_id == 0)
            exclusive = initial_value;
        else
            exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }


};
394 
395 
396 }               // CUB namespace
397 CUB_NS_POSTFIX  // Optional outer namespace(s)
398