/******************************************************************************
 * Copyright (c) 2011, Duane Merrill. All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
 */

#pragma once

#include "../../thread/thread_operators.cuh"
#include "../../thread/thread_load.cuh"
#include "../../thread/thread_store.cuh"
#include "../../util_type.cuh"
#include "../../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/**
 * \brief WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
 *
 * Storage layout: each logical warp owns WARP_SMEM_ELEMENTS cells.  Lane \c i
 * publishes its partial at index <tt>HALF_WARP_THREADS + i</tt>; the leading
 * HALF_WARP_THREADS cells are padding that low lanes read from when a scan
 * step looks back past lane 0.  The last cell (WARP_SMEM_ELEMENTS - 1) is the
 * slot of the highest lane and is where the warp aggregate is read from.
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpScanSmem
{
    /******************************************************************************
     * Constants and type definitions
     ******************************************************************************/

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// Whether the logical warp size is a power-of-two
        IS_POW_OF_TWO = PowerOfTwo<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of warp scan steps
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of threads in half a warp (also the largest look-back offset used by ScanStep)
        HALF_WARP_THREADS = 1 << (STEPS - 1),

        /// The number of shared memory elements per warp
        WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS,
    };

    /// Storage cell type (workaround for SM1x compiler bugs with custom-ops like Max() on signed chars)
    typedef typename If<((Equals<T, char>::VALUE || Equals<T, signed char>::VALUE) && (PTX_ARCH < 200)), int, T>::Type CellT;

    /// Shared memory storage layout type (1.5 warps-worth of elements for each warp)
    typedef CellT _TempStorage[WARP_SMEM_ELEMENTS];

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    _TempStorage    &temp_storage;  ///< Reference to this logical warp's storage cells
    unsigned int    lane_id;        ///< Calling thread's lane within its logical warp
    unsigned int    member_mask;    ///< Mask passed to WARP_SYNC for this logical warp


    /******************************************************************************
     * Construction
     ******************************************************************************/

    /// Constructor
    __device__ __forceinline__ WarpScanSmem(
        TempStorage     &temp_storage)
    :
        temp_storage(temp_storage.Alias()),

        // Lane index relative to the logical warp
        lane_id(IS_ARCH_WARP ?
            LaneId() :
            LaneId() % LOGICAL_WARP_THREADS),

        // LOGICAL_WARP_THREADS low bits, shifted up to the first lane of the calling
        // thread's logical warp when sub-warps tile the arch warp
        member_mask((0xffffffff >> (32 - LOGICAL_WARP_THREADS)) << ((IS_ARCH_WARP || !IS_POW_OF_TWO ) ?
            0 : // arch-width and non-power-of-two subwarps cannot be tiled with the arch-warp
            ((LaneId() / LOGICAL_WARP_THREADS) * LOGICAL_WARP_THREADS)))
    {}


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /// Basic inclusive scan iteration (template unrolled, inductive-case specialization)
    ///
    /// Each lane publishes its partial, then combines it with the partial published
    /// by the lane OFFSET positions below it.  When HAS_IDENTITY is true every lane
    /// performs the look-back unconditionally (low lanes read the padding cells);
    /// otherwise only lanes with lane_id >= OFFSET do.
    template <
        bool        HAS_IDENTITY,
        int         STEP,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &partial,       ///< [in, out] Calling thread's running partial
        ScanOp                  scan_op,        ///< [in] Binary scan operator
        Int2Type<STEP>          /*step*/)       ///< [in] Marker type carrying the current step index
    {
        const int OFFSET = 1 << STEP;

        // Share partial into buffer
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) partial);

        // All partials must be visible before any lane looks back
        WARP_SYNC(member_mask);

        // Update partial if addend is in range
        if (HAS_IDENTITY || (lane_id >= OFFSET))
        {
            T addend = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - OFFSET]);
            partial = scan_op(addend, partial);
        }

        // All look-backs must finish before the next step overwrites the cells
        WARP_SYNC(member_mask);

        ScanStep<HAS_IDENTITY>(partial, scan_op, Int2Type<STEP + 1>());
    }


    /// Basic inclusive scan iteration (template unrolled, base-case specialization)
    ///
    /// Terminates the compile-time recursion once STEP reaches STEPS; does nothing.
    template <
        bool        HAS_IDENTITY,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &/*partial*/,
        ScanOp                  /*scan_op*/,
        Int2Type<STEPS>         /*step*/)
    {}


    /// Inclusive prefix scan (specialized for summation across primitive types)
    ///
    /// Each lane first writes a zero identity into cell [lane_id], which covers the
    /// padding cells, so that the scan steps can look back without a range check.
    __device__ __forceinline__ void InclusiveScan(
        T                       input,              ///< [in] Calling thread's input item.
        T                       &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        Sum                     scan_op,            ///< [in] Binary scan operator
        Int2Type<true>          /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        T identity = 0;
        ThreadStore<STORE_VOLATILE>(&temp_storage[lane_id], (CellT) identity);

        WARP_SYNC(member_mask);

        // Iterate scan steps
        output = input;
        ScanStep<true>(output, scan_op, Int2Type<0>());
    }


    /// Inclusive prefix scan
    ///
    /// Generic path: no identity is assumed, so each scan step range-checks its look-back.
    template <typename ScanOp, int IS_PRIMITIVE>
    __device__ __forceinline__ void InclusiveScan(
        T                       input,              ///< [in] Calling thread's input item.
        T                       &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp                  scan_op,            ///< [in] Binary scan operator
        Int2Type<IS_PRIMITIVE>  /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        // Iterate scan steps
        output = input;
        ScanStep<false>(output, scan_op, Int2Type<0>());
    }


    /******************************************************************************
     * Interface
     ******************************************************************************/

    //---------------------------------------------------------------------
    // Broadcast
    //---------------------------------------------------------------------

    /// Broadcast
    ///
    /// The source lane writes \p input into the first storage cell; after a warp
    /// sync every lane returns the value read back from that cell.
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        unsigned int    src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        if (lane_id == src_lane)
        {
            ThreadStore<STORE_VOLATILE>(temp_storage, (CellT) input);
        }

        WARP_SYNC(member_mask);

        return (T)ThreadLoad<LOAD_VOLATILE>(temp_storage);
    }


    //---------------------------------------------------------------------
    // Inclusive operations
    //---------------------------------------------------------------------

    /// Inclusive scan
    ///
    /// Dispatches on whether T is a primitive type (Traits<T>::PRIMITIVE).
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        InclusiveScan(input, inclusive_output, scan_op, Int2Type<Traits<T>::PRIMITIVE>());
    }


    /// Inclusive scan with aggregate
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        InclusiveScan(input, inclusive_output, scan_op);

        // Retrieve aggregate: every lane publishes its inclusive result and then
        // reads the last cell, which is the slot of the highest lane
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        WARP_SYNC(member_mask);
    }


    //---------------------------------------------------------------------
    // Get exclusive from inclusive
    //---------------------------------------------------------------------

    /// Update inclusive and exclusive using input and inclusive
    ///
    /// Each lane publishes its inclusive value and takes its predecessor's as the
    /// exclusive value.  Lane 0 reads the padding cell just below its own slot, so
    /// its exclusive output is whatever that cell currently holds.
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T                       /*input*/,      ///< [in]
        T                       &inclusive,     ///< [in, out]
        T                       &exclusive,     ///< [out]
        ScanOpT                 /*scan_op*/,    ///< [in]
        IsIntegerT              /*is_integer*/) ///< [in]
    {
        // initial value unknown
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }

    /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types)
    ///
    /// No storage traffic: the exclusive sum is derived by subtracting the lane's own input.
    __device__ __forceinline__ void Update(
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // initial value presumed 0
        exclusive = inclusive - input;
    }

    /// Update inclusive and exclusive using initial value using input, inclusive, and initial value
    ///
    /// Folds \p initial_value into each lane's inclusive value, then takes the
    /// predecessor's updated inclusive as the exclusive value (lane 0 gets \p initial_value).
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);

        // Lane 0 has no predecessor; its exclusive value is the seed itself
        if (lane_id == 0)
            exclusive = initial_value;
    }

    /// Update inclusive and exclusive using initial value using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                scan_op,
        T                       initial_value,
        Int2Type<true>          /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        exclusive = inclusive - input;
    }


    /// Update inclusive, exclusive, and warp aggregate using input and inclusive
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 /*scan_op*/,
        IsIntegerT              /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        // Predecessor's inclusive is our exclusive; the last cell holds the aggregate
        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }

    /// Update inclusive, exclusive, and warp aggregate using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
        exclusive = inclusive - input;
    }

    /// Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value
    ///
    /// The aggregate is read before \p initial_value is folded in, so
    /// \p warp_aggregate does not include the seed.
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        // Broadcast warp aggregate
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        // Every lane must have read the aggregate before the cells are overwritten below
        WARP_SYNC(member_mask);

        // Update inclusive with initial value
        inclusive = scan_op(initial_value, inclusive);

        // Get exclusive from inclusive.  Each lane publishes into its own slot and
        // reads its predecessor's, as the other Update overloads do, so lane 0's
        // read lands in the padding cell and stays in range even when
        // HALF_WARP_THREADS == 1.
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);

        // Lane 0 has no predecessor; its exclusive value is the seed itself
        if (lane_id == 0)
            exclusive = initial_value;
    }


};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)