1 /****************************************************************************** 2 * Copyright (c) 2011, Duane Merrill. All rights reserved. 3 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are met: 7 * * Redistributions of source code must retain the above copyright 8 * notice, this list of conditions and the following disclaimer. 9 * * Redistributions in binary form must reproduce the above copyright 10 * notice, this list of conditions and the following disclaimer in the 11 * documentation and/or other materials provided with the distribution. 12 * * Neither the name of the NVIDIA CORPORATION nor the 13 * names of its contributors may be used to endorse or promote products 14 * derived from this software without specific prior written permission. 15 * 16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 20 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 * 27 ******************************************************************************/ 28 29 /** 30 * \file 31 * cub::AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . 
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentScan.
 *
 * Bundles the compile-time knobs (block size, per-thread granularity, and the
 * load/store/scan algorithms) that specialize an AgentScan instantiation for a
 * particular architecture/problem.
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    BlockStoreAlgorithm         _STORE_ALGORITHM,               ///< The BlockStore algorithm to use
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentScanPolicy
{
    enum
    {
        BLOCK_THREADS       = _BLOCK_THREADS,       ///< Threads per thread block
        ITEMS_PER_THREAD    = _ITEMS_PER_THREAD,    ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM      = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER       = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockStoreAlgorithm    STORE_ALGORITHM     = _STORE_ALGORITHM;     ///< The BlockStore algorithm to use
    static const BlockScanAlgorithm     SCAN_ALGORITHM      = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};




/******************************************************************************
 * Thread block abstractions
******************************************************************************/ 87 88 /** 89 * \brief AgentScan implements a stateful abstraction of CUDA thread blocks for participating in device-wide prefix scan . 90 */ 91 template < 92 typename AgentScanPolicyT, ///< Parameterized AgentScanPolicyT tuning policy type 93 typename InputIteratorT, ///< Random-access input iterator type 94 typename OutputIteratorT, ///< Random-access output iterator type 95 typename ScanOpT, ///< Scan functor type 96 typename InitValueT, ///< The init_value element for ScanOpT type (cub::NullType for inclusive scan) 97 typename OffsetT> ///< Signed integer type for global offsets 98 struct AgentScan 99 { 100 //--------------------------------------------------------------------- 101 // Types and constants 102 //--------------------------------------------------------------------- 103 104 // The input value type 105 typedef typename std::iterator_traits<InputIteratorT>::value_type InputT; 106 107 // The output value type 108 typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ? 109 typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type, 110 typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... 
else the output iterator's value type 111 112 // Tile status descriptor interface type 113 typedef ScanTileState<OutputT> ScanTileStateT; 114 115 // Input iterator wrapper type (for applying cache modifier) 116 typedef typename If<IsPointer<InputIteratorT>::VALUE, 117 CacheModifiedInputIterator<AgentScanPolicyT::LOAD_MODIFIER, InputT, OffsetT>, // Wrap the native input pointer with CacheModifiedInputIterator 118 InputIteratorT>::Type // Directly use the supplied input iterator type 119 WrappedInputIteratorT; 120 121 // Constants 122 enum 123 { 124 IS_INCLUSIVE = Equals<InitValueT, NullType>::VALUE, // Inclusive scan if no init_value type is provided 125 BLOCK_THREADS = AgentScanPolicyT::BLOCK_THREADS, 126 ITEMS_PER_THREAD = AgentScanPolicyT::ITEMS_PER_THREAD, 127 TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, 128 }; 129 130 // Parameterized BlockLoad type 131 typedef BlockLoad< 132 OutputT, 133 AgentScanPolicyT::BLOCK_THREADS, 134 AgentScanPolicyT::ITEMS_PER_THREAD, 135 AgentScanPolicyT::LOAD_ALGORITHM> 136 BlockLoadT; 137 138 // Parameterized BlockStore type 139 typedef BlockStore< 140 OutputT, 141 AgentScanPolicyT::BLOCK_THREADS, 142 AgentScanPolicyT::ITEMS_PER_THREAD, 143 AgentScanPolicyT::STORE_ALGORITHM> 144 BlockStoreT; 145 146 // Parameterized BlockScan type 147 typedef BlockScan< 148 OutputT, 149 AgentScanPolicyT::BLOCK_THREADS, 150 AgentScanPolicyT::SCAN_ALGORITHM> 151 BlockScanT; 152 153 // Callback type for obtaining tile prefix during block scan 154 typedef TilePrefixCallbackOp< 155 OutputT, 156 ScanOpT, 157 ScanTileStateT> 158 TilePrefixCallbackOpT; 159 160 // Stateful BlockScan prefix callback type for managing a running total while scanning consecutive tiles 161 typedef BlockScanRunningPrefixOp< 162 OutputT, 163 ScanOpT> 164 RunningPrefixCallbackOp; 165 166 // Shared memory type for this thread block 167 union _TempStorage 168 { 169 typename BlockLoadT::TempStorage load; // Smem needed for tile loading 170 typename BlockStoreT::TempStorage store; // 
Smem needed for tile storing 171 172 struct 173 { 174 typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback 175 typename BlockScanT::TempStorage scan; // Smem needed for tile scanning 176 }; 177 }; 178 179 // Alias wrapper allowing storage to be unioned 180 struct TempStorage : Uninitialized<_TempStorage> {}; 181 182 183 //--------------------------------------------------------------------- 184 // Per-thread fields 185 //--------------------------------------------------------------------- 186 187 _TempStorage& temp_storage; ///< Reference to temp_storage 188 WrappedInputIteratorT d_in; ///< Input data 189 OutputIteratorT d_out; ///< Output data 190 ScanOpT scan_op; ///< Binary scan operator 191 InitValueT init_value; ///< The init_value element for ScanOpT 192 193 194 //--------------------------------------------------------------------- 195 // Block scan utility methods 196 //--------------------------------------------------------------------- 197 198 /** 199 * Exclusive scan specialization (first tile) 200 */ 201 __device__ __forceinline__ ScanTilecub::AgentScan202 void ScanTile( 203 OutputT (&items)[ITEMS_PER_THREAD], 204 OutputT init_value, 205 ScanOpT scan_op, 206 OutputT &block_aggregate, 207 Int2Type<false> /*is_inclusive*/) 208 { 209 BlockScanT(temp_storage.scan).ExclusiveScan(items, items, init_value, scan_op, block_aggregate); 210 block_aggregate = scan_op(init_value, block_aggregate); 211 } 212 213 214 /** 215 * Inclusive scan specialization (first tile) 216 */ 217 __device__ __forceinline__ ScanTilecub::AgentScan218 void ScanTile( 219 OutputT (&items)[ITEMS_PER_THREAD], 220 InitValueT /*init_value*/, 221 ScanOpT scan_op, 222 OutputT &block_aggregate, 223 Int2Type<true> /*is_inclusive*/) 224 { 225 BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, block_aggregate); 226 } 227 228 229 /** 230 * Exclusive scan specialization (subsequent tiles) 231 */ 232 template <typename PrefixCallback> 233 
__device__ __forceinline__ ScanTilecub::AgentScan234 void ScanTile( 235 OutputT (&items)[ITEMS_PER_THREAD], 236 ScanOpT scan_op, 237 PrefixCallback &prefix_op, 238 Int2Type<false> /*is_inclusive*/) 239 { 240 BlockScanT(temp_storage.scan).ExclusiveScan(items, items, scan_op, prefix_op); 241 } 242 243 244 /** 245 * Inclusive scan specialization (subsequent tiles) 246 */ 247 template <typename PrefixCallback> 248 __device__ __forceinline__ ScanTilecub::AgentScan249 void ScanTile( 250 OutputT (&items)[ITEMS_PER_THREAD], 251 ScanOpT scan_op, 252 PrefixCallback &prefix_op, 253 Int2Type<true> /*is_inclusive*/) 254 { 255 BlockScanT(temp_storage.scan).InclusiveScan(items, items, scan_op, prefix_op); 256 } 257 258 259 //--------------------------------------------------------------------- 260 // Constructor 261 //--------------------------------------------------------------------- 262 263 // Constructor 264 __device__ __forceinline__ AgentScancub::AgentScan265 AgentScan( 266 TempStorage& temp_storage, ///< Reference to temp_storage 267 InputIteratorT d_in, ///< Input data 268 OutputIteratorT d_out, ///< Output data 269 ScanOpT scan_op, ///< Binary scan operator 270 InitValueT init_value) ///< Initial value to seed the exclusive scan 271 : 272 temp_storage(temp_storage.Alias()), 273 d_in(d_in), 274 d_out(d_out), 275 scan_op(scan_op), 276 init_value(init_value) 277 {} 278 279 280 //--------------------------------------------------------------------- 281 // Cooperatively scan a device-wide sequence of tiles with other CTAs 282 //--------------------------------------------------------------------- 283 284 /** 285 * Process a tile of input (dynamic chained scan) 286 */ 287 template <bool IS_LAST_TILE> ///< Whether the current tile is the last tile ConsumeTilecub::AgentScan288 __device__ __forceinline__ void ConsumeTile( 289 OffsetT num_remaining, ///< Number of global input items remaining (including this tile) 290 int tile_idx, ///< Tile index 291 OffsetT tile_offset, ///< 
Tile offset 292 ScanTileStateT& tile_state) ///< Global tile state descriptor 293 { 294 // Load items 295 OutputT items[ITEMS_PER_THREAD]; 296 297 if (IS_LAST_TILE) 298 BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining); 299 else 300 BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); 301 302 CTA_SYNC(); 303 304 // Perform tile scan 305 if (tile_idx == 0) 306 { 307 // Scan first tile 308 OutputT block_aggregate; 309 ScanTile(items, init_value, scan_op, block_aggregate, Int2Type<IS_INCLUSIVE>()); 310 if ((!IS_LAST_TILE) && (threadIdx.x == 0)) 311 tile_state.SetInclusive(0, block_aggregate); 312 } 313 else 314 { 315 // Scan non-first tile 316 TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); 317 ScanTile(items, scan_op, prefix_op, Int2Type<IS_INCLUSIVE>()); 318 } 319 320 CTA_SYNC(); 321 322 // Store items 323 if (IS_LAST_TILE) 324 BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, num_remaining); 325 else 326 BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); 327 } 328 329 330 /** 331 * Scan tiles of items as part of a dynamic chained scan 332 */ ConsumeRangecub::AgentScan333 __device__ __forceinline__ void ConsumeRange( 334 int num_items, ///< Total number of input items 335 ScanTileStateT& tile_state, ///< Global tile state descriptor 336 int start_tile) ///< The starting tile for the current grid 337 { 338 // Blocks are launched in increasing order, so just assign one tile per block 339 int tile_idx = start_tile + blockIdx.x; // Current tile index 340 OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile 341 OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) 342 343 if (num_remaining > TILE_ITEMS) 344 { 345 // Not last tile 346 ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state); 347 } 348 else if (num_remaining > 0) 349 { 350 // Last tile 351 
ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state); 352 } 353 } 354 355 356 //--------------------------------------------------------------------- 357 // Scan an sequence of consecutive tiles (independent of other thread blocks) 358 //--------------------------------------------------------------------- 359 360 /** 361 * Process a tile of input 362 */ 363 template < 364 bool IS_FIRST_TILE, 365 bool IS_LAST_TILE> ConsumeTilecub::AgentScan366 __device__ __forceinline__ void ConsumeTile( 367 OffsetT tile_offset, ///< Tile offset 368 RunningPrefixCallbackOp& prefix_op, ///< Running prefix operator 369 int valid_items = TILE_ITEMS) ///< Number of valid items in the tile 370 { 371 // Load items 372 OutputT items[ITEMS_PER_THREAD]; 373 374 if (IS_LAST_TILE) 375 BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, valid_items); 376 else 377 BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items); 378 379 CTA_SYNC(); 380 381 // Block scan 382 if (IS_FIRST_TILE) 383 { 384 OutputT block_aggregate; 385 ScanTile(items, init_value, scan_op, block_aggregate, Int2Type<IS_INCLUSIVE>()); 386 prefix_op.running_total = block_aggregate; 387 } 388 else 389 { 390 ScanTile(items, scan_op, prefix_op, Int2Type<IS_INCLUSIVE>()); 391 } 392 393 CTA_SYNC(); 394 395 // Store items 396 if (IS_LAST_TILE) 397 BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items, valid_items); 398 else 399 BlockStoreT(temp_storage.store).Store(d_out + tile_offset, items); 400 } 401 402 403 /** 404 * Scan a consecutive share of input tiles 405 */ ConsumeRangecub::AgentScan406 __device__ __forceinline__ void ConsumeRange( 407 OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) 408 OffsetT range_end) ///< [in] Threadblock end offset (exclusive) 409 { 410 BlockScanRunningPrefixOp<OutputT, ScanOpT> prefix_op(scan_op); 411 412 if (range_offset + TILE_ITEMS <= range_end) 413 { 414 // Consume first tile of input (full) 415 ConsumeTile<true, true>(range_offset, 
prefix_op); 416 range_offset += TILE_ITEMS; 417 418 // Consume subsequent full tiles of input 419 while (range_offset + TILE_ITEMS <= range_end) 420 { 421 ConsumeTile<false, true>(range_offset, prefix_op); 422 range_offset += TILE_ITEMS; 423 } 424 425 // Consume a partially-full tile 426 if (range_offset < range_end) 427 { 428 int valid_items = range_end - range_offset; 429 ConsumeTile<false, false>(range_offset, prefix_op, valid_items); 430 } 431 } 432 else 433 { 434 // Consume the first tile of input (partially-full) 435 int valid_items = range_end - range_offset; 436 ConsumeTile<true, false>(range_offset, prefix_op, valid_items); 437 } 438 } 439 440 441 /** 442 * Scan a consecutive share of input tiles, seeded with the specified prefix value 443 */ ConsumeRangecub::AgentScan444 __device__ __forceinline__ void ConsumeRange( 445 OffsetT range_offset, ///< [in] Threadblock begin offset (inclusive) 446 OffsetT range_end, ///< [in] Threadblock end offset (exclusive) 447 OutputT prefix) ///< [in] The prefix to apply to the scan segment 448 { 449 BlockScanRunningPrefixOp<OutputT, ScanOpT> prefix_op(prefix, scan_op); 450 451 // Consume full tiles of input 452 while (range_offset + TILE_ITEMS <= range_end) 453 { 454 ConsumeTile<true, false>(range_offset, prefix_op); 455 range_offset += TILE_ITEMS; 456 } 457 458 // Consume a partially-full tile 459 if (range_offset < range_end) 460 { 461 int valid_items = range_end - range_offset; 462 ConsumeTile<false, false>(range_offset, prefix_op, valid_items); 463 } 464 } 465 466 }; 467 468 469 } // CUB namespace 470 CUB_NS_POSTFIX // Optional outer namespace(s) 471 472