/******************************************************************************
 * Copyright (c) 2011, Duane Merrill. All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../block/block_discontinuity.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentRle
 */
template <
    int                         _BLOCK_THREADS,                 ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,              ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,                ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,                 ///< Cache load modifier for reading input elements
    bool                        _STORE_WARP_TIME_SLICING,       ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
    BlockScanAlgorithm          _SCAN_ALGORITHM>                ///< The BlockScan algorithm to use
struct AgentRlePolicy
{
    enum
    {
        BLOCK_THREADS           = _BLOCK_THREADS,               ///< Threads per thread block
        ITEMS_PER_THREAD        = _ITEMS_PER_THREAD,            ///< Items per thread (per tile of input)
        STORE_WARP_TIME_SLICING = _STORE_WARP_TIME_SLICING,     ///< Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM          = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER           = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm     SCAN_ALGORITHM          = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};





/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentRle implements a stateful abstraction of CUDA thread blocks for participating in device-wide run-length-encode
 *
 * Each thread block consumes one tile of TILE_ITEMS input items (see ConsumeRange).
 * For every run that is longer than one item, the agent writes the run's starting
 * offset to \p d_offsets_out and the run's length to \p d_lengths_out.  Runs of
 * length one (items that are both a head and a tail) contribute to neither output.
 */
template <
    typename    AgentRlePolicyT,        ///< Parameterized AgentRlePolicyT tuning policy type
    typename    InputIteratorT,         ///< Random-access input iterator type for data
    typename    OffsetsOutputIteratorT, ///< Random-access output iterator type for offset values
    typename    LengthsOutputIteratorT, ///< Random-access output iterator type for length values
    typename    EqualityOpT,            ///< T equality operator type
    typename    OffsetT>                ///< Signed integer type for global offsets
struct AgentRle
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    /// The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type T;

    /// The lengths output value type
    typedef typename If<(Equals<typename std::iterator_traits<LengthsOutputIteratorT>::value_type, void>::VALUE),   // LengthT =  (if output iterator's value type is void) ?
        OffsetT,                                                                                                    // ... then the OffsetT type,
        typename std::iterator_traits<LengthsOutputIteratorT>::value_type>::Type LengthT;                           // ... else the output iterator's value type

    /// Tuple type for scanning (pairs run-length and run-index)
    typedef KeyValuePair<OffsetT, LengthT> LengthOffsetPair;

    /// Tile status descriptor interface type
    typedef ReduceByKeyScanTileState<LengthT, OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        WARP_THREADS            = CUB_WARP_THREADS(PTX_ARCH),
        BLOCK_THREADS           = AgentRlePolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentRlePolicyT::ITEMS_PER_THREAD,
        WARP_ITEMS              = WARP_THREADS * ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
        WARPS                   = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        /// Whether or not to sync after loading data
        SYNC_AFTER_LOAD         = (AgentRlePolicyT::LOAD_ALGORITHM != BLOCK_LOAD_DIRECT),

        /// Whether or not only one warp's worth of shared memory should be allocated and time-sliced among block-warps during any store-related data transpositions (versus each warp having its own storage)
        STORE_WARP_TIME_SLICING = AgentRlePolicyT::STORE_WARP_TIME_SLICING,
        ACTIVE_EXCHANGE_WARPS   = (STORE_WARP_TIME_SLICING) ? 1 : WARPS,
    };


    /**
     * Special operator that signals all out-of-bounds items are not equal to everything else,
     * forcing both (1) the last item to be tail-flagged and (2) all oob items to be marked
     * trivial.
     */
    template <bool LAST_TILE>
    struct OobInequalityOp
    {
        OffsetT         num_remaining;      ///< Number of valid items in this tile
        EqualityOpT     equality_op;        ///< User-supplied equality operator

        /// Constructor
        __device__ __forceinline__ OobInequalityOp(
            OffsetT     num_remaining,
            EqualityOpT equality_op)
        :
            num_remaining(num_remaining),
            equality_op(equality_op)
        {}

        /**
         * Returns true when \p first and \p second differ.  In the last tile, any
         * comparison whose index \p idx is at or beyond \p num_remaining is reported
         * as "different" without consulting the equality operator.
         */
        template <typename Index>
        __host__ __device__ __forceinline__ bool operator()(T first, T second, Index idx)
        {
            if (!LAST_TILE || (idx < num_remaining))
                return !equality_op(first, second);
            else
                return true;
        }
    };


    // Cache-modified Input iterator wrapper type (for applying cache modifier) for data
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentRlePolicyT::LOAD_MODIFIER, T, OffsetT>,     // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                       // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Parameterized BlockLoad type for data
    typedef BlockLoad<
            T,
            AgentRlePolicyT::BLOCK_THREADS,
            AgentRlePolicyT::ITEMS_PER_THREAD,
            AgentRlePolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockDiscontinuity type for data
    typedef BlockDiscontinuity<T, BLOCK_THREADS> BlockDiscontinuityT;

    // Parameterized WarpScan type
    typedef WarpScan<LengthOffsetPair> WarpScanPairs;

    // Reduce-length-by-run scan operator
    typedef ReduceBySegmentOp<cub::Sum> ReduceBySegmentOpT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            LengthOffsetPair,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Warp exchange types
    typedef WarpExchange<LengthOffsetPair, ITEMS_PER_THREAD>        WarpExchangePairs;

    // Pair-exchange storage is only needed by the time-sliced scatter; otherwise it collapses to NullType
    typedef typename If<STORE_WARP_TIME_SLICING, typename WarpExchangePairs::TempStorage, NullType>::Type WarpExchangePairsStorage;

    typedef WarpExchange<OffsetT, ITEMS_PER_THREAD>                 WarpExchangeOffsets;
    typedef WarpExchange<LengthT, ITEMS_PER_THREAD>                 WarpExchangeLengths;

    typedef LengthOffsetPair WarpAggregates[WARPS];

    // Shared memory type for this thread block
    struct _TempStorage
    {
        // Aliasable storage layout
        union Aliasable
        {
            struct
            {
                typename BlockDiscontinuityT::TempStorage       discontinuity;              // Smem needed for discontinuity detection
                typename WarpScanPairs::TempStorage             warp_scan[WARPS];           // Smem needed for warp-synchronous scans
                Uninitialized<LengthOffsetPair[WARPS]>          warp_aggregates;            // Smem needed for sharing warp-wide aggregates
                typename TilePrefixCallbackOpT::TempStorage     prefix;                     // Smem needed for cooperative prefix callback
            };

            // Smem needed for input loading
            typename BlockLoadT::TempStorage                    load;

            // Aliasable layout needed for two-phase scatter
            union ScatterAliasable
            {
                unsigned long long                              align;
                WarpExchangePairsStorage                        exchange_pairs[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeOffsets::TempStorage       exchange_offsets[ACTIVE_EXCHANGE_WARPS];
                typename WarpExchangeLengths::TempStorage       exchange_lengths[ACTIVE_EXCHANGE_WARPS];

            } scatter_aliasable;

        } aliasable;

        OffsetT             tile_idx;                   // Shared tile index
        LengthOffsetPair    tile_inclusive;             // Inclusive tile prefix
        LengthOffsetPair    tile_exclusive;             // Exclusive tile prefix
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage

    WrappedInputIteratorT           d_in;               ///< Pointer to input sequence of data items
    OffsetsOutputIteratorT          d_offsets_out;      ///< Output run offsets
    LengthsOutputIteratorT          d_lengths_out;      ///< Output run lengths

    EqualityOpT                     equality_op;        ///< T equality operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-length-by-flag scan operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentRle(
        TempStorage                 &temp_storage,      ///< [in] Reference to temp_storage
        InputIteratorT              d_in,               ///< [in] Pointer to input sequence of data items
        OffsetsOutputIteratorT      d_offsets_out,      ///< [out] Pointer to output sequence of run offsets
        LengthsOutputIteratorT      d_lengths_out,      ///< [out] Pointer to output sequence of run lengths
        EqualityOpT                 equality_op,        ///< [in] T equality operator
        OffsetT                     num_items)          ///< [in] Total number of input items
    :
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_offsets_out(d_offsets_out),
        d_lengths_out(d_lengths_out),
        equality_op(equality_op),
        scan_op(cub::Sum()),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Flags run heads and tails for this tile and converts the flags into
     * per-item scan inputs.
     *
     * For each item, \p lengths_and_num_runs[ITEM].key is 1 when the item is the
     * head (but not also the tail) of a run, and .value is 1 when the item is not
     * simultaneously a head and a tail.  An item that is both head and tail
     * (a run of length one) therefore yields {0, 0}.
     *
     * Interior tile boundaries are resolved by reading the neighbouring item
     * directly from \p d_in (the predecessor by thread 0, the successor by the
     * last thread).
     */
    template <bool FIRST_TILE, bool LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT             tile_offset,                                ///< [in] Global offset of the first item in this tile
        OffsetT             num_remaining,                              ///< [in] Number of input items remaining (including this tile)
        T                   (&items)[ITEMS_PER_THREAD],                 ///< [in] This thread's items
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])  ///< [out] Per-item {run-start flag, length contribution}
    {
        bool                head_flags[ITEMS_PER_THREAD];
        bool                tail_flags[ITEMS_PER_THREAD];

        OobInequalityOp<LAST_TILE> inequality_op(num_remaining, equality_op);

        if (FIRST_TILE && LAST_TILE)
        {
            // First-and-last-tile always head-flags the first item and tail-flags the last item

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tail_flags, items, inequality_op);
        }
        else if (FIRST_TILE)
        {
            // First-tile always head-flags the first item

            // Get the first item from the next tile (only the last thread loads it;
            // it is left uninitialized in all other threads)
            T tile_successor_item;
            if (threadIdx.x == BLOCK_THREADS - 1)
                tile_successor_item = d_in[tile_offset + TILE_ITEMS];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tail_flags, tile_successor_item, items, inequality_op);
        }
        else if (LAST_TILE)
        {
            // Last-tile always flags the last item

            // Get the last item from the previous tile (only thread 0 loads it;
            // it is left uninitialized in all other threads)
            T tile_predecessor_item;
            if (threadIdx.x == 0)
                tile_predecessor_item = d_in[tile_offset - 1];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tile_predecessor_item, tail_flags, items, inequality_op);
        }
        else
        {
            // Get the first item from the next tile
            T tile_successor_item;
            if (threadIdx.x == BLOCK_THREADS - 1)
                tile_successor_item = d_in[tile_offset + TILE_ITEMS];

            // Get the last item from the previous tile
            T tile_predecessor_item;
            if (threadIdx.x == 0)
                tile_predecessor_item = d_in[tile_offset - 1];

            BlockDiscontinuityT(temp_storage.aliasable.discontinuity).FlagHeadsAndTails(
                head_flags, tile_predecessor_item, tail_flags, tile_successor_item, items, inequality_op);
        }

        // Zip counts and runs
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            lengths_and_num_runs[ITEM].key      = head_flags[ITEM] && (!tail_flags[ITEM]);
            lengths_and_num_runs[ITEM].value    = ((!head_flags[ITEM]) || (!tail_flags[ITEM]));
        }
    }

    //---------------------------------------------------------------------
    // Scan utility methods
    //---------------------------------------------------------------------

    /**
     * Scan of allocations
     *
     * Reduces each thread's items, scans the per-thread partials within each
     * warp, then combines the per-warp aggregates (shared through
     * temp_storage.aliasable.warp_aggregates) into the warp-exclusive prefix
     * and the tile-wide aggregate.  Contains a block-wide barrier, so it must
     * be called by every thread in the block.
     */
    __device__ __forceinline__ void WarpScanAllocations(
        LengthOffsetPair    &tile_aggregate,                            ///< [out] Reduction over all items in the tile
        LengthOffsetPair    &warp_aggregate,                            ///< [out] Reduction over all items in the calling thread's warp
        LengthOffsetPair    &warp_exclusive_in_tile,                    ///< [out] Reduction over all warps preceding the calling thread's warp
        LengthOffsetPair    &thread_exclusive_in_warp,                  ///< [out] Reduction over all threads preceding the calling thread within its warp
        LengthOffsetPair    (&lengths_and_num_runs)[ITEMS_PER_THREAD])  ///< [in] Per-item scan inputs
    {
        // Perform warpscans
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        LengthOffsetPair identity;
        identity.key = 0;
        identity.value = 0;

        LengthOffsetPair thread_inclusive;
        LengthOffsetPair thread_aggregate = internal::ThreadReduce(lengths_and_num_runs, scan_op);
        WarpScanPairs(temp_storage.aliasable.warp_scan[warp_id]).Scan(
            thread_aggregate,
            thread_inclusive,
            thread_exclusive_in_warp,
            identity,
            scan_op);

        // Last lane in each warp shares its warp-aggregate
        if (lane_id == WARP_THREADS - 1)
            temp_storage.aliasable.warp_aggregates.Alias()[warp_id] = thread_inclusive;

        CTA_SYNC();

        // Accumulate total selected and the warp-wide prefix
        warp_exclusive_in_tile          = identity;
        warp_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[warp_id];
        tile_aggregate                  = temp_storage.aliasable.warp_aggregates.Alias()[0];

        #pragma unroll
        for (int WARP = 1; WARP < WARPS; ++WARP)
        {
            // The running total before folding in warp WARP is that warp's exclusive prefix
            if (warp_id == WARP)
                warp_exclusive_in_tile = tile_aggregate;

            tile_aggregate = scan_op(tile_aggregate, temp_storage.aliasable.warp_aggregates.Alias()[WARP]);
        }
    }


    //---------------------------------------------------------------------
    // Utility methods for scattering selections
    //---------------------------------------------------------------------

    /**
     * Two-phase scatter, specialized for warp time-slicing
     *
     * All warps share the single exchange_pairs[0] buffer: warp 0 compacts first,
     * then each remaining warp takes its turn after a block-wide barrier.  The
     * compacted (striped) pairs are then written to global memory.
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OffsetT             tile_num_runs_exclusive_in_global,                  ///< [in] Number of runs recorded by preceding tiles
        OffsetT             warp_num_runs_aggregate,                            ///< [in] Number of runs started in this warp
        OffsetT             warp_num_runs_exclusive_in_tile,                    ///< [in] Number of runs started in preceding warps of this tile
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], ///< [in] Per-item rank within the warp (or a discard sentinel)
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],           ///< [in,out] Per-item {offset, length}; compacted in place
        Int2Type<true>      is_warp_time_slice)                                 ///< [in] Overload selector
    {
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        // Locally compact items within the warp (first warp)
        if (warp_id == 0)
        {
            WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped(
                lengths_and_offsets, thread_num_runs_exclusive_in_warp);
        }

        // Locally compact items within the warp (remaining warps)
        #pragma unroll
        for (int SLICE = 1; SLICE < WARPS; ++SLICE)
        {
            // Barrier is outside the warp_id test so that every thread in the block reaches it
            CTA_SYNC();

            if (warp_id == SLICE)
            {
                WarpExchangePairs(temp_storage.aliasable.scatter_aliasable.exchange_pairs[0]).ScatterToStriped(
                    lengths_and_offsets, thread_num_runs_exclusive_in_warp);
            }
        }

        // Global scatter
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            if ((ITEM * WARP_THREADS) < warp_num_runs_aggregate - lane_id)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    (ITEM * WARP_THREADS) + lane_id;

                // Scatter offset
                d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

                // Scatter length if not the first (global) length
                if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0))
                {
                    // The length carried by a run start belongs to the previous run
                    d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
                }
            }
        }
    }


    /**
     * Two-phase scatter
     *
     * Each warp owns its own exchange buffer.  Offsets and lengths are compacted
     * separately (their buffers alias one another, hence the warp barrier between
     * the two exchanges) and then written to global memory.
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OffsetT             tile_num_runs_exclusive_in_global,                  ///< [in] Number of runs recorded by preceding tiles
        OffsetT             warp_num_runs_aggregate,                            ///< [in] Number of runs started in this warp
        OffsetT             warp_num_runs_exclusive_in_tile,                    ///< [in] Number of runs started in preceding warps of this tile
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], ///< [in] Per-item rank within the warp (or a discard sentinel)
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD],           ///< [in] Per-item {offset, length}
        Int2Type<false>     is_warp_time_slice)                                 ///< [in] Overload selector
    {
        unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
        int lane_id = LaneId();

        // Unzip
        OffsetT run_offsets[ITEMS_PER_THREAD];
        LengthT run_lengths[ITEMS_PER_THREAD];

        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            run_offsets[ITEM] = lengths_and_offsets[ITEM].key;
            run_lengths[ITEM] = lengths_and_offsets[ITEM].value;
        }

        WarpExchangeOffsets(temp_storage.aliasable.scatter_aliasable.exchange_offsets[warp_id]).ScatterToStriped(
            run_offsets, thread_num_runs_exclusive_in_warp);

        // exchange_offsets and exchange_lengths share the same union storage
        WARP_SYNC(0xffffffff);

        WarpExchangeLengths(temp_storage.aliasable.scatter_aliasable.exchange_lengths[warp_id]).ScatterToStriped(
            run_lengths, thread_num_runs_exclusive_in_warp);

        // Global scatter
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
        {
            if ((ITEM * WARP_THREADS) + lane_id < warp_num_runs_aggregate)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    (ITEM * WARP_THREADS) + lane_id;

                // Scatter offset
                d_offsets_out[item_offset] = run_offsets[ITEM];

                // Scatter length if not the first (global) length
                if ((!FIRST_TILE) || (ITEM != 0) || (threadIdx.x > 0))
                {
                    // The length carried by a run start belongs to the previous run
                    d_lengths_out[item_offset - 1] = run_lengths[ITEM];
                }
            }
        }
    }


    /**
     * Direct scatter
     *
     * Every item whose rank is below the warp's run count is written straight to
     * global memory without an intermediate shared-memory exchange.
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void ScatterDirect(
        OffsetT             tile_num_runs_exclusive_in_global,                  ///< [in] Number of runs recorded by preceding tiles
        OffsetT             warp_num_runs_aggregate,                            ///< [in] Number of runs started in this warp
        OffsetT             warp_num_runs_exclusive_in_tile,                    ///< [in] Number of runs started in preceding warps of this tile
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], ///< [in] Per-item rank within the warp (or a discard sentinel)
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])           ///< [in] Per-item {offset, length}
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (thread_num_runs_exclusive_in_warp[ITEM] < warp_num_runs_aggregate)
            {
                OffsetT item_offset =
                    tile_num_runs_exclusive_in_global +
                    warp_num_runs_exclusive_in_tile +
                    thread_num_runs_exclusive_in_warp[ITEM];

                // Scatter offset
                d_offsets_out[item_offset] = lengths_and_offsets[ITEM].key;

                // Scatter length if not the first (global) length
                if (item_offset >= 1)
                {
                    // The length carried by a run start belongs to the previous run
                    d_lengths_out[item_offset - 1] = lengths_and_offsets[ITEM].value;
                }
            }
        }
    }


    /**
     * Scatter
     *
     * Chooses the direct scatter when there is one item per thread or the tile
     * produced fewer runs than there are threads; otherwise dispatches to the
     * two-phase scatter selected by STORE_WARP_TIME_SLICING.
     */
    template <bool FIRST_TILE>
    __device__ __forceinline__ void Scatter(
        OffsetT             tile_num_runs_aggregate,                            ///< [in] Number of runs started in this tile
        OffsetT             tile_num_runs_exclusive_in_global,                  ///< [in] Number of runs recorded by preceding tiles
        OffsetT             warp_num_runs_aggregate,                            ///< [in] Number of runs started in this warp
        OffsetT             warp_num_runs_exclusive_in_tile,                    ///< [in] Number of runs started in preceding warps of this tile
        OffsetT             (&thread_num_runs_exclusive_in_warp)[ITEMS_PER_THREAD], ///< [in] Per-item rank within the warp (or a discard sentinel)
        LengthOffsetPair    (&lengths_and_offsets)[ITEMS_PER_THREAD])           ///< [in] Per-item {offset, length}
    {
        if ((ITEMS_PER_THREAD == 1) || (tile_num_runs_aggregate < BLOCK_THREADS))
        {
            // Direct scatter if the warp has any items
            if (warp_num_runs_aggregate)
            {
                ScatterDirect<FIRST_TILE>(
                    tile_num_runs_exclusive_in_global,
                    warp_num_runs_aggregate,
                    warp_num_runs_exclusive_in_tile,
                    thread_num_runs_exclusive_in_warp,
                    lengths_and_offsets);
            }
        }
        else
        {
            // Scatter two phase
            ScatterTwoPhase<FIRST_TILE>(
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets,
                Int2Type<STORE_WARP_TIME_SLICING>());
        }
    }



    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process a tile of input (dynamic chained scan)
     *
     * Loads the tile, flags run boundaries, scans run counts and lengths, obtains
     * the global prefix from \p tile_status (tiles other than the first), and
     * scatters the offsets and lengths of the runs that start in this tile.
     *
     * \return The running {number of runs, length of trailing run} total, inclusive
     *         of this tile.
     */
    template <
        bool                LAST_TILE>
    __device__ __forceinline__ LengthOffsetPair ConsumeTile(
        OffsetT             num_items,          ///< Total number of global input items
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT      &tile_status)       ///< Global list of tile status
    {
        if (tile_idx == 0)
        {
            // First tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);

            // The load storage aliases the discontinuity/scan storage used next
            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<true, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // Update tile status if this is not the last tile
            if (!LAST_TILE && (threadIdx.x == 0))
                tile_status.SetInclusive(0, tile_aggregate);

            // Update thread_exclusive_in_warp to fold in warp run-length
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += warp_exclusive_in_tile.value;

            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];

            // Downsweep scan through lengths_and_num_runs
            internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip

            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key           = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                            lengths_and_num_runs2[ITEM].key :       // keep
                                                            WARP_THREADS * ITEMS_PER_THREAD;        // discard
            }

            OffsetT tile_num_runs_aggregate             = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global   = 0;
            OffsetT warp_num_runs_aggregate             = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile     = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<true>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return tile_aggregate;
        }
        else
        {
            // Not first tile

            // Load items
            T items[ITEMS_PER_THREAD];
            if (LAST_TILE)
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items, num_remaining, T());
            else
                BlockLoadT(temp_storage.aliasable.load).Load(d_in + tile_offset, items);

            // The load storage aliases the discontinuity/scan storage used next
            if (SYNC_AFTER_LOAD)
                CTA_SYNC();

            // Set flags
            LengthOffsetPair    lengths_and_num_runs[ITEMS_PER_THREAD];

            InitializeSelections<false, LAST_TILE>(
                tile_offset,
                num_remaining,
                items,
                lengths_and_num_runs);

            // Exclusive scan of lengths and runs
            LengthOffsetPair tile_aggregate;
            LengthOffsetPair warp_aggregate;
            LengthOffsetPair warp_exclusive_in_tile;
            LengthOffsetPair thread_exclusive_in_warp;

            WarpScanAllocations(
                tile_aggregate,
                warp_aggregate,
                warp_exclusive_in_tile,
                thread_exclusive_in_warp,
                lengths_and_num_runs);

            // First warp computes tile prefix in lane 0
            TilePrefixCallbackOpT prefix_op(tile_status, temp_storage.aliasable.prefix, Sum(), tile_idx);
            unsigned int warp_id = ((WARPS == 1) ? 0 : threadIdx.x / WARP_THREADS);
            if (warp_id == 0)
            {
                prefix_op(tile_aggregate);
                if (threadIdx.x == 0)
                    temp_storage.tile_exclusive = prefix_op.exclusive_prefix;
            }

            // Make the exclusive prefix visible to the whole block
            CTA_SYNC();

            LengthOffsetPair tile_exclusive_in_global = temp_storage.tile_exclusive;

            // Update thread_exclusive_in_warp to fold in warp and tile run-lengths
            LengthOffsetPair thread_exclusive = scan_op(tile_exclusive_in_global, warp_exclusive_in_tile);
            if (thread_exclusive_in_warp.key == 0)
                thread_exclusive_in_warp.value += thread_exclusive.value;

            // Downsweep scan through lengths_and_num_runs
            LengthOffsetPair    lengths_and_num_runs2[ITEMS_PER_THREAD];
            LengthOffsetPair    lengths_and_offsets[ITEMS_PER_THREAD];
            OffsetT             thread_num_runs_exclusive_in_warp[ITEMS_PER_THREAD];

            internal::ThreadScanExclusive(lengths_and_num_runs, lengths_and_num_runs2, scan_op, thread_exclusive_in_warp);

            // Zip
            #pragma unroll
            for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
            {
                lengths_and_offsets[ITEM].value         = lengths_and_num_runs2[ITEM].value;
                lengths_and_offsets[ITEM].key           = tile_offset + (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
                thread_num_runs_exclusive_in_warp[ITEM] = (lengths_and_num_runs[ITEM].key) ?
                                                            lengths_and_num_runs2[ITEM].key :       // keep
                                                            WARP_THREADS * ITEMS_PER_THREAD;        // discard
            }

            OffsetT tile_num_runs_aggregate             = tile_aggregate.key;
            OffsetT tile_num_runs_exclusive_in_global   = tile_exclusive_in_global.key;
            OffsetT warp_num_runs_aggregate             = warp_aggregate.key;
            OffsetT warp_num_runs_exclusive_in_tile     = warp_exclusive_in_tile.key;

            // Scatter
            Scatter<false>(
                tile_num_runs_aggregate,
                tile_num_runs_exclusive_in_global,
                warp_num_runs_aggregate,
                warp_num_runs_exclusive_in_tile,
                thread_num_runs_exclusive_in_warp,
                lengths_and_offsets);

            // Return running total (inclusive of this tile)
            return prefix_op.inclusive_prefix;
        }
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     *
     * Each thread block processes the single tile identified by its block index.
     * The block that handles the last tile also writes the total number of runs
     * and the length of the final run.
     */
    template <typename NumRunsIteratorT>            ///< Output iterator type for recording the number of runs identified
    __device__ __forceinline__ void ConsumeRange(
        int                 num_tiles,              ///< Total number of input tiles
        ScanTileStateT&     tile_status,            ///< Global list of tile status
        NumRunsIteratorT    d_num_runs_out)         ///< Output pointer for total number of runs identified
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Widen BEFORE multiplying: tile_idx and TILE_ITEMS are both int-ranked, so the
        // product would otherwise be evaluated in 32 bits and overflow for inputs of
        // 2^31 items or more even when OffsetT is a 64-bit type.
        OffsetT tile_offset     = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);   // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                  // Remaining items (including this tile)

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_items, num_remaining, tile_idx, tile_offset, tile_status);
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            LengthOffsetPair running_total = ConsumeTile<true>(num_items, num_remaining, tile_idx, tile_offset, tile_status);

            if (threadIdx.x == 0)
            {
                // Output the total number of runs identified
                *d_num_runs_out = running_total.key;

                // The inclusive prefix contains accumulated length reduction for the last run
                if (running_total.key > 0)
                    d_lengths_out[running_total.key - 1] = running_total.value;
            }
        }
    }
};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
