1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill.  All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *     * Redistributions of source code must retain the above copyright
8  *       notice, this list of conditions and the following disclaimer.
9  *     * Redistributions in binary form must reproduce the above copyright
10  *       notice, this list of conditions and the following disclaimer in the
11  *       documentation and/or other materials provided with the distribution.
12  *     * Neither the name of the NVIDIA CORPORATION nor the
13  *       names of its contributors may be used to endorse or promote products
14  *       derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  ******************************************************************************/
28 
29 /**
30  * \file
31  * PTX intrinsics
32  */
33 
34 
35 #pragma once
36 
37 #include "util_type.cuh"
38 #include "util_arch.cuh"
39 #include "util_namespace.cuh"
40 #include "util_debug.cuh"
41 
42 
43 /// Optional outer namespace(s)
44 CUB_NS_PREFIX
45 
46 /// CUB namespace
47 namespace cub {
48 
49 
50 /**
51  * \addtogroup UtilPtx
52  * @{
53  */
54 
55 
56 /******************************************************************************
57  * PTX helper macros
58  ******************************************************************************/
59 
#ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document

/**
 * Register modifier for pointer-types (for inlining PTX assembly).
 * Selects the 64-bit ("l") vs. 32-bit ("r") register constraint based on the
 * host pointer model (_WIN64 on MSVC, __LP64__ elsewhere), so that pointer
 * operands bind to correctly-sized PTX registers.
 */
#if defined(_WIN64) || defined(__LP64__)
    #define __CUB_LP64__ 1
    // 64-bit register modifier for inlined asm
    #define _CUB_ASM_PTR_ "l"
    #define _CUB_ASM_PTR_SIZE_ "u64"
#else
    #define __CUB_LP64__ 0
    // 32-bit register modifier for inlined asm
    #define _CUB_ASM_PTR_ "r"
    #define _CUB_ASM_PTR_SIZE_ "u32"
#endif

#endif // DOXYGEN_SHOULD_SKIP_THIS
78 
79 
80 /******************************************************************************
81  * Inlined PTX intrinsics
82  ******************************************************************************/
83 
84 /**
85  * \brief Shift-right then add.  Returns (\p x >> \p shift) + \p addend.
86  */
SHR_ADD(unsigned int x,unsigned int shift,unsigned int addend)87 __device__ __forceinline__ unsigned int SHR_ADD(
88     unsigned int x,
89     unsigned int shift,
90     unsigned int addend)
91 {
92     unsigned int ret;
93 #if CUB_PTX_ARCH >= 200
94     asm ("vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
95         "=r"(ret) : "r"(x), "r"(shift), "r"(addend));
96 #else
97     ret = (x >> shift) + addend;
98 #endif
99     return ret;
100 }
101 
102 
103 /**
104  * \brief Shift-left then add.  Returns (\p x << \p shift) + \p addend.
105  */
SHL_ADD(unsigned int x,unsigned int shift,unsigned int addend)106 __device__ __forceinline__ unsigned int SHL_ADD(
107     unsigned int x,
108     unsigned int shift,
109     unsigned int addend)
110 {
111     unsigned int ret;
112 #if CUB_PTX_ARCH >= 200
113     asm ("vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
114         "=r"(ret) : "r"(x), "r"(shift), "r"(addend));
115 #else
116     ret = (x << shift) + addend;
117 #endif
118     return ret;
119 }
120 
121 #ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document
122 
123 /**
124  * Bitfield-extract.
125  */
126 template <typename UnsignedBits, int BYTE_LEN>
BFE(UnsignedBits source,unsigned int bit_start,unsigned int num_bits,Int2Type<BYTE_LEN>)127 __device__ __forceinline__ unsigned int BFE(
128     UnsignedBits            source,
129     unsigned int            bit_start,
130     unsigned int            num_bits,
131     Int2Type<BYTE_LEN>      /*byte_len*/)
132 {
133     unsigned int bits;
134 #if CUB_PTX_ARCH >= 200
135     asm ("bfe.u32 %0, %1, %2, %3;" : "=r"(bits) : "r"((unsigned int) source), "r"(bit_start), "r"(num_bits));
136 #else
137     const unsigned int MASK = (1 << num_bits) - 1;
138     bits = (source >> bit_start) & MASK;
139 #endif
140     return bits;
141 }
142 
143 
144 /**
145  * Bitfield-extract for 64-bit types.
146  */
147 template <typename UnsignedBits>
BFE(UnsignedBits source,unsigned int bit_start,unsigned int num_bits,Int2Type<8>)148 __device__ __forceinline__ unsigned int BFE(
149     UnsignedBits            source,
150     unsigned int            bit_start,
151     unsigned int            num_bits,
152     Int2Type<8>             /*byte_len*/)
153 {
154     const unsigned long long MASK = (1ull << num_bits) - 1;
155     return (source >> bit_start) & MASK;
156 }
157 
158 #endif // DOXYGEN_SHOULD_SKIP_THIS
159 
160 /**
161  * \brief Bitfield-extract.  Extracts \p num_bits from \p source starting at bit-offset \p bit_start.  The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
162  */
163 template <typename UnsignedBits>
BFE(UnsignedBits source,unsigned int bit_start,unsigned int num_bits)164 __device__ __forceinline__ unsigned int BFE(
165     UnsignedBits source,
166     unsigned int bit_start,
167     unsigned int num_bits)
168 {
169     return BFE(source, bit_start, num_bits, Int2Type<sizeof(UnsignedBits)>());
170 }
171 
172 
173 /**
174  * \brief Bitfield insert.  Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start.
175  */
BFI(unsigned int & ret,unsigned int x,unsigned int y,unsigned int bit_start,unsigned int num_bits)176 __device__ __forceinline__ void BFI(
177     unsigned int &ret,
178     unsigned int x,
179     unsigned int y,
180     unsigned int bit_start,
181     unsigned int num_bits)
182 {
183 #if CUB_PTX_ARCH >= 200
184     asm ("bfi.b32 %0, %1, %2, %3, %4;" :
185         "=r"(ret) : "r"(y), "r"(x), "r"(bit_start), "r"(num_bits));
186 #else
187     x <<= bit_start;
188     unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
189     unsigned int MASK_Y = ~MASK_X;
190     ret = (y & MASK_Y) | (x & MASK_X);
191 #endif
192 }
193 
194 
195 /**
196  * \brief Three-operand add.  Returns \p x + \p y + \p z.
197  */
IADD3(unsigned int x,unsigned int y,unsigned int z)198 __device__ __forceinline__ unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
199 {
200 #if CUB_PTX_ARCH >= 200
201     asm ("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(x) : "r"(x), "r"(y), "r"(z));
202 #else
203     x = x + y + z;
204 #endif
205     return x;
206 }
207 
208 
209 /**
210  * \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register.  For SM2.0 or later.
211  *
212  * \par
213  * The bytes in the two source registers \p a and \p b are numbered from 0 to 7:
214  * {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes
215  * {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within
216  * the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0}
217  *
218  * \par Snippet
219  * The code snippet below illustrates byte-permute.
220  * \par
221  * \code
222  * #include <cub/cub.cuh>
223  *
224  * __global__ void ExampleKernel(...)
225  * {
226  *     int a        = 0x03020100;
227  *     int b        = 0x07060504;
228  *     int index    = 0x00007531;
229  *
230  *     int selected = PRMT(a, b, index);    // 0x07050301
231  *
232  * \endcode
233  *
234  */
PRMT(unsigned int a,unsigned int b,unsigned int index)235 __device__ __forceinline__ int PRMT(unsigned int a, unsigned int b, unsigned int index)
236 {
237     int ret;
238     asm ("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index));
239     return ret;
240 }
241 
242 #ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document
243 
244 /**
245  * Sync-threads barrier.
246  */
BAR(int count)247 __device__ __forceinline__ void BAR(int count)
248 {
249     asm volatile("bar.sync 1, %0;" : : "r"(count));
250 }
251 
252 /**
253  * CTA barrier
254  */
CTA_SYNC()255 __device__  __forceinline__ void CTA_SYNC()
256 {
257     __syncthreads();
258 }
259 
260 
261 /**
262  * CTA barrier with predicate
263  */
CTA_SYNC_AND(int p)264 __device__  __forceinline__ int CTA_SYNC_AND(int p)
265 {
266     return __syncthreads_and(p);
267 }
268 
269 
270 /**
271  * Warp barrier
272  */
WARP_SYNC(unsigned int member_mask)273 __device__  __forceinline__ void WARP_SYNC(unsigned int member_mask)
274 {
275 #ifdef CUB_USE_COOPERATIVE_GROUPS
276     __syncwarp(member_mask);
277 #endif
278 }
279 
280 
281 /**
282  * Warp any
283  */
WARP_ANY(int predicate,unsigned int member_mask)284 __device__  __forceinline__ int WARP_ANY(int predicate, unsigned int member_mask)
285 {
286 #ifdef CUB_USE_COOPERATIVE_GROUPS
287     return __any_sync(member_mask, predicate);
288 #else
289     return ::__any(predicate);
290 #endif
291 }
292 
293 
294 /**
295  * Warp any
296  */
WARP_ALL(int predicate,unsigned int member_mask)297 __device__  __forceinline__ int WARP_ALL(int predicate, unsigned int member_mask)
298 {
299 #ifdef CUB_USE_COOPERATIVE_GROUPS
300     return __all_sync(member_mask, predicate);
301 #else
302     return ::__all(predicate);
303 #endif
304 }
305 
306 
307 /**
308  * Warp ballot
309  */
WARP_BALLOT(int predicate,unsigned int member_mask)310 __device__  __forceinline__ int WARP_BALLOT(int predicate, unsigned int member_mask)
311 {
312 #ifdef CUB_USE_COOPERATIVE_GROUPS
313     return __ballot_sync(member_mask, predicate);
314 #else
315     return __ballot(predicate);
316 #endif
317 }
318 
319 /**
320  * Warp synchronous shfl_up
321  */
322 __device__ __forceinline__
SHFL_UP_SYNC(unsigned int word,int src_offset,int flags,unsigned int member_mask)323 unsigned int SHFL_UP_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
324 {
325 #ifdef CUB_USE_COOPERATIVE_GROUPS
326     asm volatile("shfl.sync.up.b32 %0, %1, %2, %3, %4;"
327         : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
328 #else
329     asm volatile("shfl.up.b32 %0, %1, %2, %3;"
330         : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags));
331 #endif
332     return word;
333 }
334 
335 /**
336  * Warp synchronous shfl_down
337  */
338 __device__ __forceinline__
SHFL_DOWN_SYNC(unsigned int word,int src_offset,int flags,unsigned int member_mask)339 unsigned int SHFL_DOWN_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
340 {
341 #ifdef CUB_USE_COOPERATIVE_GROUPS
342     asm volatile("shfl.sync.down.b32 %0, %1, %2, %3, %4;"
343         : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
344 #else
345     asm volatile("shfl.down.b32 %0, %1, %2, %3;"
346         : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags));
347 #endif
348     return word;
349 }
350 
351 /**
352  * Warp synchronous shfl_idx
353  */
354 __device__ __forceinline__
SHFL_IDX_SYNC(unsigned int word,int src_lane,int flags,unsigned int member_mask)355 unsigned int SHFL_IDX_SYNC(unsigned int word, int src_lane, int flags, unsigned int member_mask)
356 {
357 #ifdef CUB_USE_COOPERATIVE_GROUPS
358     asm volatile("shfl.sync.idx.b32 %0, %1, %2, %3, %4;"
359         : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags), "r"(member_mask));
360 #else
361     asm volatile("shfl.idx.b32 %0, %1, %2, %3;"
362         : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags));
363 #endif
364     return word;
365 }
366 
367 /**
368  * Floating point multiply. (Mantissa LSB rounds towards zero.)
369  */
FMUL_RZ(float a,float b)370 __device__ __forceinline__ float FMUL_RZ(float a, float b)
371 {
372     float d;
373     asm ("mul.rz.f32 %0, %1, %2;" : "=f"(d) : "f"(a), "f"(b));
374     return d;
375 }
376 
377 
378 /**
379  * Floating point multiply-add. (Mantissa LSB rounds towards zero.)
380  */
FFMA_RZ(float a,float b,float c)381 __device__ __forceinline__ float FFMA_RZ(float a, float b, float c)
382 {
383     float d;
384     asm ("fma.rz.f32 %0, %1, %2, %3;" : "=f"(d) : "f"(a), "f"(b), "f"(c));
385     return d;
386 }
387 
388 #endif // DOXYGEN_SHOULD_SKIP_THIS
389 
390 /**
391  * \brief Terminates the calling thread
392  */
ThreadExit()393 __device__ __forceinline__ void ThreadExit() {
394     asm volatile("exit;");
395 }
396 
397 
398 /**
399  * \brief  Abort execution and generate an interrupt to the host CPU
400  */
ThreadTrap()401 __device__ __forceinline__ void ThreadTrap() {
402     asm volatile("trap;");
403 }
404 
405 
406 /**
407  * \brief Returns the row-major linear thread identifier for a multidimensional thread block
408  */
RowMajorTid(int block_dim_x,int block_dim_y,int block_dim_z)409 __device__ __forceinline__ int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
410 {
411     return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) +
412             ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) +
413             threadIdx.x;
414 }
415 
416 
417 /**
418  * \brief Returns the warp lane ID of the calling thread
419  */
LaneId()420 __device__ __forceinline__ unsigned int LaneId()
421 {
422     unsigned int ret;
423     asm ("mov.u32 %0, %%laneid;" : "=r"(ret) );
424     return ret;
425 }
426 
427 
428 /**
429  * \brief Returns the warp ID of the calling thread.  Warp ID is guaranteed to be unique among warps, but may not correspond to a zero-based ranking within the thread block.
430  */
WarpId()431 __device__ __forceinline__ unsigned int WarpId()
432 {
433     unsigned int ret;
434     asm ("mov.u32 %0, %%warpid;" : "=r"(ret) );
435     return ret;
436 }
437 
438 /**
439  * \brief Returns the warp lane mask of all lanes less than the calling thread
440  */
LaneMaskLt()441 __device__ __forceinline__ unsigned int LaneMaskLt()
442 {
443     unsigned int ret;
444     asm ("mov.u32 %0, %%lanemask_lt;" : "=r"(ret) );
445     return ret;
446 }
447 
448 /**
449  * \brief Returns the warp lane mask of all lanes less than or equal to the calling thread
450  */
LaneMaskLe()451 __device__ __forceinline__ unsigned int LaneMaskLe()
452 {
453     unsigned int ret;
454     asm ("mov.u32 %0, %%lanemask_le;" : "=r"(ret) );
455     return ret;
456 }
457 
458 /**
459  * \brief Returns the warp lane mask of all lanes greater than the calling thread
460  */
LaneMaskGt()461 __device__ __forceinline__ unsigned int LaneMaskGt()
462 {
463     unsigned int ret;
464     asm ("mov.u32 %0, %%lanemask_gt;" : "=r"(ret) );
465     return ret;
466 }
467 
468 /**
469  * \brief Returns the warp lane mask of all lanes greater than or equal to the calling thread
470  */
LaneMaskGe()471 __device__ __forceinline__ unsigned int LaneMaskGe()
472 {
473     unsigned int ret;
474     asm ("mov.u32 %0, %%lanemask_ge;" : "=r"(ret) );
475     return ret;
476 }
477 
478 /** @} */       // end group UtilPtx
479 
480 
481 
482 
483 /**
484  * \brief Shuffle-up for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input contributed by <em>warp-lane</em><sub><em>i</em>-<tt>src_offset</tt></sub>.  For thread lanes \e i < src_offset, the thread's own \p input is returned to the thread. ![](shfl_up_logo.png)
485  * \ingroup WarpModule
486  *
487  * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
488  * \tparam T                        <b>[inferred]</b> The input/output element type
489  *
490  * \par
491  * - Available only for SM3.0 or newer
492  *
493  * \par Snippet
494  * The code snippet below illustrates each thread obtaining a \p double value from the
495  * predecessor of its predecessor.
496  * \par
497  * \code
498  * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
499  *
500  * __global__ void ExampleKernel(...)
501  * {
502  *     // Obtain one input item per thread
503  *     double thread_data = ...
504  *
505  *     // Obtain item from two ranks below
506  *     double peer_data = ShuffleUp<32>(thread_data, 2, 0, 0xffffffff);
507  *
508  * \endcode
509  * \par
510  * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
511  * The corresponding output \p peer_data will be <tt>{1.0, 2.0, 1.0, 2.0, 3.0, ..., 30.0}</tt>.
512  *
513  */
template <
    int LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleUp(
    T               input,              ///< [in] The value to broadcast
    int             src_offset,         ///< [in] The relative down-offset of the peer to read from
    int             first_thread,       ///< [in] Index of first lane in logical warp (typically 0)
    unsigned int    member_mask)        ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    enum {
        SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8
    };

    // Arbitrary T is shuffled one machine word at a time through the
    // 32-bit shuffle datapath
    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    // Number of shuffle words covering T (ceiling division)
    const int       WORDS           = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T               output;
    ShuffleWord     *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord     *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    // Shuffle word 0 first, then the remainder in an unrolled loop;
    // first_thread is folded into the shfl clamp operand alongside SHFL_C
    unsigned int shuffle_word;
    shuffle_word = SHFL_UP_SYNC((unsigned int)input_alias[0], src_offset, first_thread | SHFL_C, member_mask);
    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word       = SHFL_UP_SYNC((unsigned int)input_alias[WORD], src_offset, first_thread | SHFL_C, member_mask);
        output_alias[WORD] = shuffle_word;
    }

    return output;
}
549 
550 
551 /**
552  * \brief Shuffle-down for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input contributed by <em>warp-lane</em><sub><em>i</em>+<tt>src_offset</tt></sub>.  For thread lanes \e i >= WARP_THREADS, the thread's own \p input is returned to the thread.  ![](shfl_down_logo.png)
553  * \ingroup WarpModule
554  *
555  * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
556  * \tparam T                        <b>[inferred]</b> The input/output element type
557  *
558  * \par
559  * - Available only for SM3.0 or newer
560  *
561  * \par Snippet
562  * The code snippet below illustrates each thread obtaining a \p double value from the
563  * successor of its successor.
564  * \par
565  * \code
566  * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
567  *
568  * __global__ void ExampleKernel(...)
569  * {
570  *     // Obtain one input item per thread
571  *     double thread_data = ...
572  *
573  *     // Obtain item from two ranks below
574  *     double peer_data = ShuffleDown<32>(thread_data, 2, 31, 0xffffffff);
575  *
576  * \endcode
577  * \par
578  * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
579  * The corresponding output \p peer_data will be <tt>{3.0, 4.0, 5.0, 6.0, 7.0, ..., 32.0}</tt>.
580  *
581  */
template <
    int LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleDown(
    T               input,              ///< [in] The value to broadcast
    int             src_offset,         ///< [in] The relative up-offset of the peer to read from
    int             last_thread,        ///< [in] Index of last thread in logical warp (typically 31 for a 32-thread warp)
    unsigned int    member_mask)        ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    enum {
        SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8
    };

    // Arbitrary T is shuffled one machine word at a time through the
    // 32-bit shuffle datapath
    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    // Number of shuffle words covering T (ceiling division)
    const int       WORDS           = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T               output;
    ShuffleWord     *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord     *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    // Shuffle word 0 first, then the remainder in an unrolled loop;
    // last_thread is folded into the shfl clamp operand alongside SHFL_C
    unsigned int shuffle_word;
    shuffle_word    = SHFL_DOWN_SYNC((unsigned int)input_alias[0], src_offset, last_thread | SHFL_C, member_mask);
    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word       = SHFL_DOWN_SYNC((unsigned int)input_alias[WORD], src_offset, last_thread | SHFL_C, member_mask);
        output_alias[WORD] = shuffle_word;
    }

    return output;
}
617 
618 
619 /**
620  * \brief Shuffle-broadcast for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input
621  * contributed by <em>warp-lane</em><sub><tt>src_lane</tt></sub>.  For \p src_lane < 0 or \p src_lane >= WARP_THREADS,
622  * then the thread's own \p input is returned to the thread. ![](shfl_broadcast_logo.png)
623  *
624  * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
625  * \tparam T                        <b>[inferred]</b> The input/output element type
626  *
627  * \ingroup WarpModule
628  *
629  * \par
630  * - Available only for SM3.0 or newer
631  *
632  * \par Snippet
633  * The code snippet below illustrates each thread obtaining a \p double value from <em>warp-lane</em><sub>0</sub>.
634  *
635  * \par
636  * \code
637  * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
638  *
639  * __global__ void ExampleKernel(...)
640  * {
641  *     // Obtain one input item per thread
642  *     double thread_data = ...
643  *
644  *     // Obtain item from thread 0
645  *     double peer_data = ShuffleIndex<32>(thread_data, 0, 0xffffffff);
646  *
647  * \endcode
648  * \par
649  * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
650  * The corresponding output \p peer_data will be <tt>{1.0, 1.0, 1.0, 1.0, 1.0, ..., 1.0}</tt>.
651  *
652  */
template <
    int LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleIndex(
    T               input,                  ///< [in] The value to broadcast
    int             src_lane,               ///< [in] Which warp lane is to do the broadcasting
    unsigned int    member_mask)            ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    /// (low bits additionally clamp the source lane to the logical warp)
    enum {
        SHFL_C = ((32 - LOGICAL_WARP_THREADS) << 8) | (LOGICAL_WARP_THREADS - 1)
    };

    // Arbitrary T is shuffled one machine word at a time through the
    // 32-bit shuffle datapath
    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    // Number of shuffle words covering T (ceiling division)
    const int       WORDS           = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T               output;
    ShuffleWord     *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord     *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    // Shuffle word 0 first, then the remainder in an unrolled loop
    unsigned int shuffle_word;
    shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[0],
                                 src_lane,
                                 SHFL_C,
                                 member_mask);

    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[WORD],
                                     src_lane,
                                     SHFL_C,
                                     member_mask);

        output_alias[WORD] = shuffle_word;
    }

    return output;
}
695 
696 
697 
698 /**
699  * Compute a 32b mask of threads having the same least-significant
700  * LABEL_BITS of \p label as the calling thread.
701  */
702 template <int LABEL_BITS>
MatchAny(unsigned int label)703 inline __device__ unsigned int MatchAny(unsigned int label)
704 {
705     unsigned int retval;
706 
707     // Extract masks of common threads for each bit
708     #pragma unroll
709     for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
710     {
711         unsigned int mask;
712         unsigned int current_bit = 1 << BIT;
713         asm ("{\n"
714             "    .reg .pred p;\n"
715             "    and.b32 %0, %1, %2;"
716             "    setp.eq.u32 p, %0, %2;\n"
717 #ifdef CUB_USE_COOPERATIVE_GROUPS
718             "    vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
719 #else
720             "    vote.ballot.b32 %0, p;\n"
721 #endif
722             "    @!p not.b32 %0, %0;\n"
723             "}\n" : "=r"(mask) : "r"(label), "r"(current_bit));
724 
725         // Remove peers who differ
726         retval = (BIT == 0) ? mask : retval & mask;
727     }
728 
729     return retval;
730 
731 //  // VOLTA match
732 //    unsigned int retval;
733 //    asm ("{\n"
734 //         "    match.any.sync.b32 %0, %1, 0xffffffff;\n"
735 //         "}\n" : "=r"(retval) : "r"(label));
736 //    return retval;
737 
738 }
739 
740 
741 
742 
743 
744 
745 
746 
747 
748 
749 
750 
751 
752 
753 
754 
755 
756 
757 }               // CUB namespace
758 CUB_NS_POSTFIX  // Optional outer namespace(s)
759