/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::WarpReduceShfl provides SHFL-based variants of parallel reduction of items partitioned across a CUDA thread warp.
 */

#pragma once

#include "../../thread/thread_operators.cuh"
#include "../../util_ptx.cuh"
#include "../../util_type.cuh"
#include "../../util_macro.cuh"
#include "../../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/**
 * \brief WarpReduceShfl provides SHFL-based variants of parallel reduction of items partitioned across a CUDA thread warp.
 *
 * LOGICAL_WARP_THREADS must be a power-of-two
 */
template <
    typename    T,                      ///< Data type being reduced
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpReduceShfl
{
    //---------------------------------------------------------------------
    // Constants and type definitions
    //---------------------------------------------------------------------

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// The number of warp reduction steps
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// Number of logical warps in a PTX warp
        LOGICAL_WARPS = CUB_WARP_THREADS(PTX_ARCH) / LOGICAL_WARP_THREADS,

        /// The 5-bit SHFL mask for logically splitting warps into sub-segments starts 8-bits up
        SHFL_C = (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS) << 8

    };

    template <typename S>
    struct IsInteger
    {
        enum {
            /// Whether the data type is a small (32b or less) integer for which we can use a single SHFL instruction per exchange
            IS_SMALL_UNSIGNED = (Traits<S>::CATEGORY == UNSIGNED_INTEGER) && (sizeof(S) <= sizeof(unsigned int))
        };
    };


    /// Shared memory storage layout type
    typedef NullType TempStorage;


    //---------------------------------------------------------------------
    // Thread fields
    //---------------------------------------------------------------------

    /// Lane index in logical warp
    unsigned int lane_id;

    /// Logical warp index in 32-thread physical warp
    unsigned int warp_id;

    /// 32-thread physical warp member mask of logical warp
    unsigned int member_mask;


    //---------------------------------------------------------------------
    // Construction
    //---------------------------------------------------------------------

    /// Constructor
    __device__ __forceinline__ WarpReduceShfl(
        TempStorage &/*temp_storage*/)
    {
        lane_id = LaneId();
        warp_id = 0;
        member_mask = 0xffffffffu >> (CUB_WARP_THREADS(PTX_ARCH) - LOGICAL_WARP_THREADS);

        if (!IS_ARCH_WARP)
        {
            // Remap physical lane to (logical warp, logical lane) and shift the
            // member mask up to cover this logical warp's physical lanes
            warp_id = lane_id / LOGICAL_WARP_THREADS;
            lane_id = lane_id % LOGICAL_WARP_THREADS;
            member_mask = member_mask << (warp_id * LOGICAL_WARP_THREADS);
        }
    }


    //---------------------------------------------------------------------
    // Reduction steps
    //---------------------------------------------------------------------

    /// Reduction (specialized for summation across uint32 types)
    __device__ __forceinline__ unsigned int ReduceStep(
        unsigned int    input,              ///< [in] Calling thread's input item.
        cub::Sum        /*reduction_op*/,   ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        int             offset)             ///< [in] Up-offset to pull from
    {
        unsigned int output;
        int shfl_c = last_lane | SHFL_C;   // Shuffle control (mask and last_lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 r0;"
            "  .reg .pred p;"
            "  shfl.sync.down.b32 r0|p, %1, %2, %3, %5;"
            "  @p add.u32 r0, r0, %4;"
            "  mov.u32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 r0;"
            "  .reg .pred p;"
            "  shfl.down.b32 r0|p, %1, %2, %3;"
            "  @p add.u32 r0, r0, %4;"
            "  mov.u32 %0, r0;"
            "}"
            : "=r"(output) : "r"(input), "r"(offset), "r"(shfl_c), "r"(input));
#endif

        return output;
    }


    /// Reduction (specialized for summation across fp32 types)
    __device__ __forceinline__ float ReduceStep(
        float           input,              ///< [in] Calling thread's input item.
        cub::Sum        /*reduction_op*/,   ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        int             offset)             ///< [in] Up-offset to pull from
    {
        float output;
        int shfl_c = last_lane | SHFL_C;   // Shuffle control (mask and last_lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .f32 r0;"
            "  .reg .pred p;"
            "  shfl.sync.down.b32 r0|p, %1, %2, %3, %5;"
            "  @p add.f32 r0, r0, %4;"
            "  mov.f32 %0, r0;"
            "}"
            : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .f32 r0;"
            "  .reg .pred p;"
            "  shfl.down.b32 r0|p, %1, %2, %3;"
            "  @p add.f32 r0, r0, %4;"
            "  mov.f32 %0, r0;"
            "}"
            : "=f"(output) : "f"(input), "r"(offset), "r"(shfl_c), "f"(input));
#endif

        return output;
    }


    /// Reduction (specialized for summation across unsigned long long types)
    __device__ __forceinline__ unsigned long long ReduceStep(
        unsigned long long  input,              ///< [in] Calling thread's input item.
        cub::Sum            /*reduction_op*/,   ///< [in] Binary reduction operator
        int                 last_lane,          ///< [in] Index of last lane in segment
        int                 offset)             ///< [in] Up-offset to pull from
    {
        unsigned long long output;
        int shfl_c = last_lane | SHFL_C;   // Shuffle control (mask and last_lane)

        // 64-bit value is exchanged as two 32-bit halves; predicate from the
        // second SHFL guards the add against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.down.b32 lo|p, lo, %2, %3, %4;"
            "  shfl.sync.down.b32 hi|p, hi, %2, %3, %4;"
            "  mov.b64 %0, {lo, hi};"
            "  @p add.u64 %0, %0, %1;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.down.b32 lo|p, lo, %2, %3;"
            "  shfl.down.b32 hi|p, hi, %2, %3;"
            "  mov.b64 %0, {lo, hi};"
            "  @p add.u64 %0, %0, %1;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c));
#endif

        return output;
    }


    /// Reduction (specialized for summation across long long types)
    __device__ __forceinline__ long long ReduceStep(
        long long       input,              ///< [in] Calling thread's input item.
        cub::Sum        /*reduction_op*/,   ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        int             offset)             ///< [in] Up-offset to pull from
    {
        long long output;
        int shfl_c = last_lane | SHFL_C;   // Shuffle control (mask and last_lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.down.b32 lo|p, lo, %2, %3, %4;"
            "  shfl.sync.down.b32 hi|p, hi, %2, %3, %4;"
            "  mov.b64 %0, {lo, hi};"
            "  @p add.s64 %0, %0, %1;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.down.b32 lo|p, lo, %2, %3;"
            "  shfl.down.b32 hi|p, hi, %2, %3;"
            "  mov.b64 %0, {lo, hi};"
            "  @p add.s64 %0, %0, %1;"
            "}"
            : "=l"(output) : "l"(input), "r"(offset), "r"(shfl_c));
#endif

        return output;
    }


    /// Reduction (specialized for summation across double types)
    __device__ __forceinline__ double ReduceStep(
        double          input,              ///< [in] Calling thread's input item.
        cub::Sum        /*reduction_op*/,   ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        int             offset)             ///< [in] Up-offset to pull from
    {
        double output;
        int shfl_c = last_lane | SHFL_C;   // Shuffle control (mask and last_lane)

        // Use predicate set from SHFL to guard against invalid peers
#ifdef CUB_USE_COOPERATIVE_GROUPS
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .f64 r0;"
            "  mov.b64 %0, %1;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.sync.down.b32 lo|p, lo, %2, %3, %4;"
            "  shfl.sync.down.b32 hi|p, hi, %2, %3, %4;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.f64 %0, %0, r0;"
            "}"
            : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c), "r"(member_mask));
#else
        asm volatile(
            "{"
            "  .reg .u32 lo;"
            "  .reg .u32 hi;"
            "  .reg .pred p;"
            "  .reg .f64 r0;"
            "  mov.b64 %0, %1;"
            "  mov.b64 {lo, hi}, %1;"
            "  shfl.down.b32 lo|p, lo, %2, %3;"
            "  shfl.down.b32 hi|p, hi, %2, %3;"
            "  mov.b64 r0, {lo, hi};"
            "  @p add.f64 %0, %0, r0;"
            "}"
            : "=d"(output) : "d"(input), "r"(offset), "r"(shfl_c));
#endif

        return output;
    }


    /// Reduction (specialized for swizzled ReduceByKeyOp<cub::Sum> across KeyValuePair<KeyT, ValueT> types)
    template <typename ValueT, typename KeyT>
    __device__ __forceinline__ KeyValuePair<KeyT, ValueT> ReduceStep(
        KeyValuePair<KeyT, ValueT>                  input,              ///< [in] Calling thread's input item.
        SwizzleScanOp<ReduceByKeyOp<cub::Sum> >     /*reduction_op*/,   ///< [in] Binary reduction operator
        int                                         last_lane,          ///< [in] Index of last lane in segment
        int                                         offset)             ///< [in] Up-offset to pull from
    {
        KeyValuePair<KeyT, ValueT> output;

        KeyT other_key = ShuffleDown<LOGICAL_WARP_THREADS>(input.key, offset, last_lane, member_mask);

        output.key = input.key;
        output.value = ReduceStep(
            input.value,
            cub::Sum(),
            last_lane,
            offset,
            Int2Type<IsInteger<ValueT>::IS_SMALL_UNSIGNED>());

        // Only aggregate values that belong to the same key segment
        if (input.key != other_key)
            output.value = input.value;

        return output;
    }



    /// Reduction (specialized for swizzled ReduceBySegmentOp<cub::Sum> across KeyValuePair<OffsetT, ValueT> types)
    template <typename ValueT, typename OffsetT>
    __device__ __forceinline__ KeyValuePair<OffsetT, ValueT> ReduceStep(
        KeyValuePair<OffsetT, ValueT>               input,              ///< [in] Calling thread's input item.
        SwizzleScanOp<ReduceBySegmentOp<cub::Sum> > /*reduction_op*/,   ///< [in] Binary reduction operator
        int                                         last_lane,          ///< [in] Index of last lane in segment
        int                                         offset)             ///< [in] Up-offset to pull from
    {
        KeyValuePair<OffsetT, ValueT> output;

        output.value = ReduceStep(input.value, cub::Sum(), last_lane, offset, Int2Type<IsInteger<ValueT>::IS_SMALL_UNSIGNED>());
        output.key = ReduceStep(input.key, cub::Sum(), last_lane, offset, Int2Type<IsInteger<OffsetT>::IS_SMALL_UNSIGNED>());

        // A nonzero key marks a segment boundary: keep the local value
        if (input.key > 0)
            output.value = input.value;

        return output;
    }


    /// Reduction step (generic)
    template <typename _T, typename ReductionOp>
    __device__ __forceinline__ _T ReduceStep(
        _T              input,              ///< [in] Calling thread's input item.
        ReductionOp     reduction_op,       ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        int             offset)             ///< [in] Up-offset to pull from
    {
        _T output = input;

        _T temp = ShuffleDown<LOGICAL_WARP_THREADS>(output, offset, last_lane, member_mask);

        // Perform reduction op if valid
        if (offset + lane_id <= last_lane)
            output = reduction_op(input, temp);

        return output;
    }


    /// Reduction step (specialized for small unsigned integers size 32b or less)
    template <typename _T, typename ReductionOp>
    __device__ __forceinline__ _T ReduceStep(
        _T              input,                  ///< [in] Calling thread's input item.
        ReductionOp     reduction_op,           ///< [in] Binary reduction operator
        int             last_lane,              ///< [in] Index of last lane in segment
        int             offset,                 ///< [in] Up-offset to pull from
        Int2Type<true>  /*is_small_unsigned*/)  ///< [in] Marker type indicating whether T is a small unsigned integer
    {
        return ReduceStep(input, reduction_op, last_lane, offset);
    }


    /// Reduction step (specialized for types other than small unsigned integers size 32b or less)
    template <typename _T, typename ReductionOp>
    __device__ __forceinline__ _T ReduceStep(
        _T              input,                  ///< [in] Calling thread's input item.
        ReductionOp     reduction_op,           ///< [in] Binary reduction operator
        int             last_lane,              ///< [in] Index of last lane in segment
        int             offset,                 ///< [in] Up-offset to pull from
        Int2Type<false> /*is_small_unsigned*/)  ///< [in] Marker type indicating whether T is a small unsigned integer
    {
        return ReduceStep(input, reduction_op, last_lane, offset);
    }


    //---------------------------------------------------------------------
    // Templated inclusive scan iteration
    //---------------------------------------------------------------------

    template <typename ReductionOp, int STEP>
    __device__ __forceinline__ void ReduceStep(
        T&              input,              ///< [in] Calling thread's input item.
        ReductionOp     reduction_op,       ///< [in] Binary reduction operator
        int             last_lane,          ///< [in] Index of last lane in segment
        Int2Type<STEP>  /*step*/)
    {
        input = ReduceStep(input, reduction_op, last_lane, 1 << STEP, Int2Type<IsInteger<T>::IS_SMALL_UNSIGNED>());

        ReduceStep(input, reduction_op, last_lane, Int2Type<STEP + 1>());
    }

    template <typename ReductionOp>
    __device__ __forceinline__ void ReduceStep(
        T&              /*input*/,          ///< [in] Calling thread's input item.
        ReductionOp     /*reduction_op*/,   ///< [in] Binary reduction operator
        int             /*last_lane*/,      ///< [in] Index of last lane in segment
        Int2Type<STEPS> /*step*/)
    {}


    //---------------------------------------------------------------------
    // Reduction operations
    //---------------------------------------------------------------------

    /// Reduction
    template <
        bool            ALL_LANES_VALID,    ///< Whether all lanes in each warp are contributing a valid fold of items
        typename        ReductionOp>
    __device__ __forceinline__ T Reduce(
        T               input,              ///< [in] Calling thread's input
        int             valid_items,        ///< [in] Total number of valid items across the logical warp
        ReductionOp     reduction_op)       ///< [in] Binary reduction operator
    {
        int last_lane = (ALL_LANES_VALID) ?
                            LOGICAL_WARP_THREADS - 1 :
                            valid_items - 1;

        T output = input;

        // Template-iterate reduction steps
        ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

        return output;
    }


    /// Segmented reduction
    template <
        bool            HEAD_SEGMENTED,     ///< Whether flags indicate a segment-head or a segment-tail
        typename        FlagT,
        typename        ReductionOp>
    __device__ __forceinline__ T SegmentedReduce(
        T               input,              ///< [in] Calling thread's input
        FlagT           flag,               ///< [in] Whether or not the current lane is a segment head/tail
        ReductionOp     reduction_op)       ///< [in] Binary reduction operator
    {
        // Get the start flags for each thread in the warp.
        int warp_flags = WARP_BALLOT(flag, member_mask);

        // Convert to tail-segmented
        if (HEAD_SEGMENTED)
            warp_flags >>= 1;

        // Mask out the bits below the current thread
        warp_flags &= LaneMaskGe();

        // Mask of physical lanes outside the logical warp and convert to logical lanemask
        if (!IS_ARCH_WARP)
        {
            warp_flags = (warp_flags & member_mask) >> (warp_id * LOGICAL_WARP_THREADS);
        }

        // Mask in the last lane of logical warp
        warp_flags |= 1u << (LOGICAL_WARP_THREADS - 1);

        // Find the next set flag
        int last_lane = __clz(__brev(warp_flags));

        T output = input;

        // Template-iterate reduction steps
        ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

        return output;
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)