/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * PTX intrinsics
 */


#pragma once

#include "util_type.cuh"
#include "util_arch.cuh"
#include "util_namespace.cuh"
#include "util_debug.cuh"


/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/**
 * \addtogroup UtilPtx
 * @{
 */


/******************************************************************************
 * PTX helper macros
 ******************************************************************************/

#ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document

/**
 * Register modifier for pointer-types (for inlining PTX assembly)
 */
#if defined(_WIN64) || defined(__LP64__)
    #define __CUB_LP64__ 1
    // 64-bit register modifier for inlined asm
    #define _CUB_ASM_PTR_ "l"
    #define _CUB_ASM_PTR_SIZE_ "u64"
#else
    #define __CUB_LP64__ 0
    // 32-bit register modifier for inlined asm
    #define _CUB_ASM_PTR_ "r"
    #define _CUB_ASM_PTR_SIZE_ "u32"
#endif

#endif  // DOXYGEN_SHOULD_SKIP_THIS


/******************************************************************************
 * Inlined PTX intrinsics
 ******************************************************************************/

/**
 * \brief Shift-right then add.  Returns (\p x >> \p shift) + \p addend.
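 *
 * \par Snippet
 * The code snippet below illustrates shift-right-then-add (values shown are
 * illustrative only).
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     unsigned int x = 16;
 *
 *     unsigned int result = SHR_ADD(x, 2, 1);     // (16 >> 2) + 1 = 5
 *
 * \endcode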
 */
__device__ __forceinline__ unsigned int SHR_ADD(
    unsigned int x,
    unsigned int shift,
    unsigned int addend)
{
    unsigned int ret;
#if CUB_PTX_ARCH >= 200
    asm ("vshr.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
        "=r"(ret) : "r"(x), "r"(shift), "r"(addend));
#else
    ret = (x >> shift) + addend;
#endif
    return ret;
}


/**
 * \brief Shift-left then add.  Returns (\p x << \p shift) + \p addend.
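 *
 * \par Snippet
 * The code snippet below illustrates shift-left-then-add (values shown are
 * illustrative only).
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     unsigned int x = 5;
 *
 *     unsigned int result = SHL_ADD(x, 3, 2);     // (5 << 3) + 2 = 42
 *
 * \endcode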
 */
__device__ __forceinline__ unsigned int SHL_ADD(
    unsigned int x,
    unsigned int shift,
    unsigned int addend)
{
    unsigned int ret;
#if CUB_PTX_ARCH >= 200
    asm ("vshl.u32.u32.u32.clamp.add %0, %1, %2, %3;" :
        "=r"(ret) : "r"(x), "r"(shift), "r"(addend));
#else
    ret = (x << shift) + addend;
#endif
    return ret;
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document

/**
 * Bitfield-extract.
 */
template <typename UnsignedBits, int BYTE_LEN>
__device__ __forceinline__ unsigned int BFE(
    UnsignedBits source,
    unsigned int bit_start,
    unsigned int num_bits,
    Int2Type<BYTE_LEN> /*byte_len*/)
{
    unsigned int bits;
#if CUB_PTX_ARCH >= 200
    asm ("bfe.u32 %0, %1, %2, %3;" : "=r"(bits) : "r"((unsigned int) source), "r"(bit_start), "r"(num_bits));
#else
    const unsigned int MASK = (1 << num_bits) - 1;
    bits = (source >> bit_start) & MASK;
#endif
    return bits;
}


/**
 * Bitfield-extract for 64-bit types.
 */
template <typename UnsignedBits>
__device__ __forceinline__ unsigned int BFE(
    UnsignedBits source,
    unsigned int bit_start,
    unsigned int num_bits,
    Int2Type<8> /*byte_len*/)
{
    const unsigned long long MASK = (1ull << num_bits) - 1;
    return (source >> bit_start) & MASK;
}

#endif  // DOXYGEN_SHOULD_SKIP_THIS

/**
 * \brief Bitfield-extract.  Extracts \p num_bits from \p source starting at bit-offset \p bit_start.  The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
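 *
 * \par Snippet
 * The code snippet below illustrates bitfield-extract (values shown are
 * illustrative only).
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     unsigned int source = 0x000000F0;
 *
 *     unsigned int bits = BFE(source, 4, 4);      // extracts 0xF
 *
 * \endcode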
 */
template <typename UnsignedBits>
__device__ __forceinline__ unsigned int BFE(
    UnsignedBits source,
    unsigned int bit_start,
    unsigned int num_bits)
{
    return BFE(source, bit_start, num_bits, Int2Type<sizeof(UnsignedBits)>());
}


/**
 * \brief Bitfield insert.  Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start.
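 *
 * \par Snippet
 * The code snippet below illustrates bitfield-insert (values shown are
 * illustrative only).
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     unsigned int x = 0x0000FF00;
 *     unsigned int y = 0x0000000A;
 *     unsigned int ret;
 *
 *     BFI(ret, x, y, 4, 4);                       // ret == 0x0000FFA0
 *
 * \endcode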
 */
__device__ __forceinline__ void BFI(
    unsigned int &ret,
    unsigned int x,
    unsigned int y,
    unsigned int bit_start,
    unsigned int num_bits)
{
#if CUB_PTX_ARCH >= 200
    asm ("bfi.b32 %0, %1, %2, %3, %4;" :
        "=r"(ret) : "r"(y), "r"(x), "r"(bit_start), "r"(num_bits));
#else
    // Insert the num_bits least-significant bits of y into x at bit_start
    // (mirrors the bfi.b32 path above)
    unsigned int MASK = ((1 << num_bits) - 1) << bit_start;
    ret = (x & ~MASK) | ((y << bit_start) & MASK);
#endif
}


/**
 * \brief Three-operand add.  Returns \p x + \p y + \p z.
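 *
 * \par Snippet
 * The code snippet below illustrates the three-operand add (values shown are
 * illustrative only).
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     unsigned int sum = IADD3(1, 2, 3);          // 6
 *
 * \endcode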
 */
__device__ __forceinline__ unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
#if CUB_PTX_ARCH >= 200
    asm ("vadd.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(x) : "r"(x), "r"(y), "r"(z));
#else
    x = x + y + z;
#endif
    return x;
}


/**
 * \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register.  For SM2.0 or later.
 *
 * \par
 * The bytes in the two source registers \p a and \p b are numbered from 0 to 7:
 * {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes
 * {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within
 * the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0}
 *
 * \par Snippet
 * The code snippet below illustrates byte-permute.
 * \par
 * \code
 * #include <cub/cub.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     int a        = 0x03020100;
 *     int b        = 0x07060504;
 *     int index    = 0x00007531;
 *
 *     int selected = PRMT(a, b, index);    // 0x07050301
 *
 * \endcode
 *
 */
__device__ __forceinline__ int PRMT(unsigned int a, unsigned int b, unsigned int index)
{
    int ret;
    asm ("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index));
    return ret;
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS    // Do not document

/**
 * Sync-threads barrier.
 */
__device__ __forceinline__ void BAR(int count)
{
    asm volatile("bar.sync 1, %0;" : : "r"(count));
}

/**
 * CTA barrier
 */
__device__ __forceinline__ void CTA_SYNC()
{
    __syncthreads();
}


/**
 * CTA barrier with predicate
 */
__device__ __forceinline__ int CTA_SYNC_AND(int p)
{
    return __syncthreads_and(p);
}


/**
 * Warp barrier
 */
__device__ __forceinline__ void WARP_SYNC(unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    __syncwarp(member_mask);
#endif
}


/**
 * Warp any
 */
__device__ __forceinline__ int WARP_ANY(int predicate, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    return __any_sync(member_mask, predicate);
#else
    return ::__any(predicate);
#endif
}


/**
 * Warp all
 */
__device__ __forceinline__ int WARP_ALL(int predicate, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    return __all_sync(member_mask, predicate);
#else
    return ::__all(predicate);
#endif
}


/**
 * Warp ballot
 */
__device__ __forceinline__ int WARP_BALLOT(int predicate, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    return __ballot_sync(member_mask, predicate);
#else
    return __ballot(predicate);
#endif
}

/**
 * Warp synchronous shfl_up
 */
__device__ __forceinline__
unsigned int SHFL_UP_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    asm volatile("shfl.sync.up.b32 %0, %1, %2, %3, %4;"
        : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
#else
    asm volatile("shfl.up.b32 %0, %1, %2, %3;"
        : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags));
#endif
    return word;
}

/**
 * Warp synchronous shfl_down
 */
__device__ __forceinline__
unsigned int SHFL_DOWN_SYNC(unsigned int word, int src_offset, int flags, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    asm volatile("shfl.sync.down.b32 %0, %1, %2, %3, %4;"
        : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags), "r"(member_mask));
#else
    asm volatile("shfl.down.b32 %0, %1, %2, %3;"
        : "=r"(word) : "r"(word), "r"(src_offset), "r"(flags));
#endif
    return word;
}

/**
 * Warp synchronous shfl_idx
 */
__device__ __forceinline__
unsigned int SHFL_IDX_SYNC(unsigned int word, int src_lane, int flags, unsigned int member_mask)
{
#ifdef CUB_USE_COOPERATIVE_GROUPS
    asm volatile("shfl.sync.idx.b32 %0, %1, %2, %3, %4;"
        : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags), "r"(member_mask));
#else
    asm volatile("shfl.idx.b32 %0, %1, %2, %3;"
        : "=r"(word) : "r"(word), "r"(src_lane), "r"(flags));
#endif
    return word;
}

/**
 * Floating point multiply. (Mantissa LSB rounds towards zero.)
 */
__device__ __forceinline__ float FMUL_RZ(float a, float b)
{
    float d;
    asm ("mul.rz.f32 %0, %1, %2;" : "=f"(d) : "f"(a), "f"(b));
    return d;
}


/**
 * Floating point multiply-add. (Mantissa LSB rounds towards zero.)
 */
__device__ __forceinline__ float FFMA_RZ(float a, float b, float c)
{
    float d;
    asm ("fma.rz.f32 %0, %1, %2, %3;" : "=f"(d) : "f"(a), "f"(b), "f"(c));
    return d;
}

#endif  // DOXYGEN_SHOULD_SKIP_THIS

/**
 * \brief Terminates the calling thread
 */
__device__ __forceinline__ void ThreadExit() {
    asm volatile("exit;");
}


/**
 * \brief Abort execution and generate an interrupt to the host CPU
 */
__device__ __forceinline__ void ThreadTrap() {
    asm volatile("trap;");
}


/**
 * \brief Returns the row-major linear thread identifier for a multidimensional thread block
 */
__device__ __forceinline__ int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
{
    return ((block_dim_z == 1) ? 0 : (threadIdx.z * block_dim_x * block_dim_y)) +
            ((block_dim_y == 1) ? 0 : (threadIdx.y * block_dim_x)) +
            threadIdx.x;
}


/**
 * \brief Returns the warp lane ID of the calling thread
 */
__device__ __forceinline__ unsigned int LaneId()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%laneid;" : "=r"(ret) );
    return ret;
}


/**
 * \brief Returns the warp ID of the calling thread.  Warp ID is guaranteed to be unique among warps, but may not correspond to a zero-based ranking within the thread block.
 */
__device__ __forceinline__ unsigned int WarpId()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%warpid;" : "=r"(ret) );
    return ret;
}

/**
 * \brief Returns the warp lane mask of all lanes less than the calling thread
 */
__device__ __forceinline__ unsigned int LaneMaskLt()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%lanemask_lt;" : "=r"(ret) );
    return ret;
}

/**
 * \brief Returns the warp lane mask of all lanes less than or equal to the calling thread
 */
__device__ __forceinline__ unsigned int LaneMaskLe()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%lanemask_le;" : "=r"(ret) );
    return ret;
}

/**
 * \brief Returns the warp lane mask of all lanes greater than the calling thread
 */
__device__ __forceinline__ unsigned int LaneMaskGt()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%lanemask_gt;" : "=r"(ret) );
    return ret;
}

/**
 * \brief Returns the warp lane mask of all lanes greater than or equal to the calling thread
 */
__device__ __forceinline__ unsigned int LaneMaskGe()
{
    unsigned int ret;
    asm ("mov.u32 %0, %%lanemask_ge;" : "=r"(ret) );
    return ret;
}

/** @} */       // end group UtilPtx



/**
 * \brief Shuffle-up for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input contributed by <em>warp-lane</em><sub><em>i</em>-<tt>src_offset</tt></sub>.  For thread lanes \e i < src_offset, the thread's own \p input is returned to the thread. ![](shfl_up_logo.png)
 * \ingroup WarpModule
 *
 * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
 * \tparam T                        <b>[inferred]</b> The input/output element type
 *
 * \par
 * - Available only for SM3.0 or newer
 *
 * \par Snippet
 * The code snippet below illustrates each thread obtaining a \p double value from the
 * predecessor of its predecessor.
 * \par
 * \code
 * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     // Obtain one input item per thread
 *     double thread_data = ...
 *
 *     // Obtain item from two ranks below
 *     double peer_data = ShuffleUp<32>(thread_data, 2, 0, 0xffffffff);
 *
 * \endcode
 * \par
 * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
 * The corresponding output \p peer_data will be <tt>{1.0, 2.0, 1.0, 2.0, 3.0, ..., 30.0}</tt>.
 *
 */
template <
    int LOGICAL_WARP_THREADS,           ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleUp(
    T               input,              ///< [in] The value to broadcast
    int             src_offset,         ///< [in] The relative down-offset of the peer to read from
    int             first_thread,       ///< [in] Index of first lane in logical warp (typically 0)
    unsigned int    member_mask)        ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    enum {
        SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8
    };

    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T output;
    ShuffleWord *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    unsigned int shuffle_word;
    shuffle_word = SHFL_UP_SYNC((unsigned int)input_alias[0], src_offset, first_thread | SHFL_C, member_mask);
    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word       = SHFL_UP_SYNC((unsigned int)input_alias[WORD], src_offset, first_thread | SHFL_C, member_mask);
        output_alias[WORD] = shuffle_word;
    }

    return output;
}


/**
 * \brief Shuffle-down for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input contributed by <em>warp-lane</em><sub><em>i</em>+<tt>src_offset</tt></sub>.  For thread lanes \e i >= WARP_THREADS, the thread's own \p input is returned to the thread. ![](shfl_down_logo.png)
 * \ingroup WarpModule
 *
 * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
 * \tparam T                        <b>[inferred]</b> The input/output element type
 *
 * \par
 * - Available only for SM3.0 or newer
 *
 * \par Snippet
 * The code snippet below illustrates each thread obtaining a \p double value from the
 * successor of its successor.
 * \par
 * \code
 * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     // Obtain one input item per thread
 *     double thread_data = ...
 *
 *     // Obtain item from two ranks above
 *     double peer_data = ShuffleDown<32>(thread_data, 2, 31, 0xffffffff);
 *
 * \endcode
 * \par
 * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
 * The corresponding output \p peer_data will be <tt>{3.0, 4.0, 5.0, 6.0, 7.0, ..., 32.0}</tt>.
 *
 */
template <
    int LOGICAL_WARP_THREADS,           ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleDown(
    T               input,              ///< [in] The value to broadcast
    int             src_offset,         ///< [in] The relative up-offset of the peer to read from
    int             last_thread,        ///< [in] Index of last thread in logical warp (typically 31 for a 32-thread warp)
    unsigned int    member_mask)        ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    enum {
        SHFL_C = (32 - LOGICAL_WARP_THREADS) << 8
    };

    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T output;
    ShuffleWord *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    unsigned int shuffle_word;
    shuffle_word = SHFL_DOWN_SYNC((unsigned int)input_alias[0], src_offset, last_thread | SHFL_C, member_mask);
    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word       = SHFL_DOWN_SYNC((unsigned int)input_alias[WORD], src_offset, last_thread | SHFL_C, member_mask);
        output_alias[WORD] = shuffle_word;
    }

    return output;
}


/**
 * \brief Shuffle-broadcast for any data type.  Each <em>warp-lane<sub>i</sub></em> obtains the value \p input
 * contributed by <em>warp-lane</em><sub><tt>src_lane</tt></sub>.  For \p src_lane < 0 or \p src_lane >= WARP_THREADS,
 * the thread's own \p input is returned to the thread. ![](shfl_broadcast_logo.png)
 *
 * \tparam LOGICAL_WARP_THREADS     The number of threads per "logical" warp.  Must be a power-of-two <= 32.
 * \tparam T                        <b>[inferred]</b> The input/output element type
 *
 * \ingroup WarpModule
 *
 * \par
 * - Available only for SM3.0 or newer
 *
 * \par Snippet
 * The code snippet below illustrates each thread obtaining a \p double value from <em>warp-lane</em><sub>0</sub>.
 *
 * \par
 * \code
 * #include <cub/cub.cuh>   // or equivalently <cub/util_ptx.cuh>
 *
 * __global__ void ExampleKernel(...)
 * {
 *     // Obtain one input item per thread
 *     double thread_data = ...
 *
 *     // Obtain item from thread 0
 *     double peer_data = ShuffleIndex<32>(thread_data, 0, 0xffffffff);
 *
 * \endcode
 * \par
 * Suppose the set of input \p thread_data across the first warp of threads is <tt>{1.0, 2.0, 3.0, 4.0, 5.0, ..., 32.0}</tt>.
 * The corresponding output \p peer_data will be <tt>{1.0, 1.0, 1.0, 1.0, 1.0, ..., 1.0}</tt>.
 *
 */
template <
    int LOGICAL_WARP_THREADS,           ///< Number of threads per logical warp
    typename T>
__device__ __forceinline__ T ShuffleIndex(
    T               input,              ///< [in] The value to broadcast
    int             src_lane,           ///< [in] Which warp lane is to do the broadcasting
    unsigned int    member_mask)        ///< [in] 32-bit mask of participating warp lanes
{
    /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
    enum {
        SHFL_C = ((32 - LOGICAL_WARP_THREADS) << 8) | (LOGICAL_WARP_THREADS - 1)
    };

    typedef typename UnitWord<T>::ShuffleWord ShuffleWord;

    const int WORDS = (sizeof(T) + sizeof(ShuffleWord) - 1) / sizeof(ShuffleWord);

    T output;
    ShuffleWord *output_alias   = reinterpret_cast<ShuffleWord *>(&output);
    ShuffleWord *input_alias    = reinterpret_cast<ShuffleWord *>(&input);

    unsigned int shuffle_word;
    shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[0],
                                 src_lane,
                                 SHFL_C,
                                 member_mask);

    output_alias[0] = shuffle_word;

    #pragma unroll
    for (int WORD = 1; WORD < WORDS; ++WORD)
    {
        shuffle_word = SHFL_IDX_SYNC((unsigned int)input_alias[WORD],
                                     src_lane,
                                     SHFL_C,
                                     member_mask);

        output_alias[WORD] = shuffle_word;
    }

    return output;
}



/**
 * Compute a 32b mask of threads having the same least-significant
 * LABEL_BITS of \p label as the calling thread.
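 *
 * Usage sketch (with a hypothetical per-thread \p bucket value):
 *
 *     unsigned int peers = MatchAny<4>(bucket);   // ballot mask of warp lanes whose low 4 bits of bucket match this lane's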
 */
template <int LABEL_BITS>
inline __device__ unsigned int MatchAny(unsigned int label)
{
    unsigned int retval;

    // Extract masks of common threads for each bit
    #pragma unroll
    for (int BIT = 0; BIT < LABEL_BITS; ++BIT)
    {
        unsigned int mask;
        unsigned int current_bit = 1 << BIT;
        asm ("{\n"
            "    .reg .pred p;\n"
            "    and.b32 %0, %1, %2;"
            "    setp.eq.u32 p, %0, %2;\n"
#ifdef CUB_USE_COOPERATIVE_GROUPS
            "    vote.ballot.sync.b32 %0, p, 0xffffffff;\n"
#else
            "    vote.ballot.b32 %0, p;\n"
#endif
            "    @!p not.b32 %0, %0;\n"
            "}\n" : "=r"(mask) : "r"(label), "r"(current_bit));

        // Remove peers who differ
        retval = (BIT == 0) ? mask : retval & mask;
    }

    return retval;

//  // VOLTA match
//  unsigned int retval;
//  asm ("{\n"
//      "    match.any.sync.b32 %0, %1, 0xffffffff;\n"
//      "}\n" : "=r"(retval) : "r"(label));
//  return retval;

}


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)