/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select.
 */

#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_exchange.cuh"
#include "../block/block_discontinuity.cuh"
#include "../grid/grid_queue.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentSelectIf
 */
template <
    int                 _BLOCK_THREADS,         ///< Threads per thread block
    int                 _ITEMS_PER_THREAD,      ///< Items per thread (per tile of input)
    BlockLoadAlgorithm  _LOAD_ALGORITHM,        ///< The BlockLoad algorithm to use
    CacheLoadModifier   _LOAD_MODIFIER,         ///< Cache load modifier for reading input elements
    BlockScanAlgorithm  _SCAN_ALGORITHM>        ///< The BlockScan algorithm to use
struct AgentSelectIfPolicy
{
    enum
    {
        BLOCK_THREADS       = _BLOCK_THREADS,       ///< Threads per thread block
        ITEMS_PER_THREAD    = _ITEMS_PER_THREAD     ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm LOAD_ALGORITHM  = _LOAD_ALGORITHM;  ///< The BlockLoad algorithm to use
    static const CacheLoadModifier  LOAD_MODIFIER   = _LOAD_MODIFIER;   ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm SCAN_ALGORITHM  = _SCAN_ALGORITHM;  ///< The BlockScan algorithm to use
};


/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection
 *
 * Performs functor-based selection if SelectOpT functor type != NullType
 * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType
 * Otherwise performs discontinuity selection (keep unique)
 */
template <
    typename    AgentSelectIfPolicyT,       ///< Parameterized AgentSelectIfPolicy tuning policy type
    typename    InputIteratorT,             ///< Random-access input iterator type for selection items
    typename    FlagsInputIteratorT,        ///< Random-access input iterator type for selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
    typename    SelectedOutputIteratorT,    ///< Random-access output iterator type for selected items
    typename    SelectOpT,                  ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
    typename    EqualityOpT,                ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection)
    typename    OffsetT,                    ///< Signed integer type for global offsets
    bool        KEEP_REJECTS>               ///< Whether or not we push rejected items to the back of the output
struct AgentSelectIf
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // The input value type
    typedef typename std::iterator_traits<InputIteratorT>::value_type InputT;

    // The output value type
    typedef typename If<(Equals<typename std::iterator_traits<SelectedOutputIteratorT>::value_type, void>::VALUE),  // OutputT =  (if output iterator's value type is void) ?
        typename std::iterator_traits<InputIteratorT>::value_type,                                                  // ... then the input iterator's value type,
        typename std::iterator_traits<SelectedOutputIteratorT>::value_type>::Type OutputT;                          // ... else the output iterator's value type

    // The flag value type
    typedef typename std::iterator_traits<FlagsInputIteratorT>::value_type FlagT;

    // Tile status descriptor interface type
    typedef ScanTileState<OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        // Selection method tags (values 0, 1, 2; used with Int2Type for overload dispatch)
        USE_SELECT_OP,
        USE_SELECT_FLAGS,
        USE_DISCONTINUITY,

        BLOCK_THREADS           = AgentSelectIfPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD        = AgentSelectIfPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS              = BLOCK_THREADS * ITEMS_PER_THREAD,
        TWO_PHASE_SCATTER       = (ITEMS_PER_THREAD > 1),

        // Functor-based if a functor type is given, else flag-based if a flag type is given, else discontinuity
        SELECT_METHOD           = (!Equals<SelectOpT, NullType>::VALUE) ?
                                    USE_SELECT_OP :
                                    (!Equals<FlagT, NullType>::VALUE) ?
                                        USE_SELECT_FLAGS :
                                        USE_DISCONTINUITY
    };

    // Cache-modified input iterator wrapper type (for applying cache modifier) for items
    typedef typename If<IsPointer<InputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, InputT, OffsetT>,   // Wrap the native input pointer with CacheModifiedInputIterator
            InputIteratorT>::Type                                                               // Directly use the supplied input iterator type
        WrappedInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for flags
    typedef typename If<IsPointer<FlagsInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSelectIfPolicyT::LOAD_MODIFIER, FlagT, OffsetT>,    // Wrap the native input pointer with CacheModifiedInputIterator
            FlagsInputIteratorT>::Type                                                          // Directly use the supplied input iterator type
        WrappedFlagsInputIteratorT;

    // Parameterized BlockLoad type for input data
    typedef BlockLoad<
            OutputT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadT;

    // Parameterized BlockLoad type for flags
    typedef BlockLoad<
            FlagT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSelectIfPolicyT::LOAD_ALGORITHM>
        BlockLoadFlags;

    // Parameterized BlockDiscontinuity type for items
    typedef BlockDiscontinuity<
            OutputT,
            BLOCK_THREADS>
        BlockDiscontinuityT;

    // Parameterized BlockScan type
    typedef BlockScan<
            OffsetT,
            BLOCK_THREADS,
            AgentSelectIfPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            OffsetT,
            cub::Sum,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Item exchange type
    typedef OutputT ItemExchangeT[TILE_ITEMS];

    // Shared memory type for this thread block.  The members are unioned, so
    // each phase (load items, load flags, scan, exchange) reuses the same
    // storage; the CTA_SYNC() calls in the methods below separate those phases.
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
            typename BlockDiscontinuityT::TempStorage       discontinuity;  // Smem needed for discontinuity detection
        };

        // Smem needed for loading items
        typename BlockLoadT::TempStorage load_items;

        // Smem needed for loading flags
        typename BlockLoadFlags::TempStorage load_flags;

        // Smem needed for compacting items (allows non POD items in this union)
        Uninitialized<ItemExchangeT> raw_exchange;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedInputIteratorT           d_in;               ///< Input items
    SelectedOutputIteratorT         d_selected_out;     ///< Selected output items
    WrappedFlagsInputIteratorT      d_flags_in;         ///< Input selection flags (if applicable)
    InequalityWrapper<EqualityOpT>  inequality_op;      ///< Inequality operator (wraps the supplied equality operator)
    SelectOpT                       select_op;          ///< Selection operator
    OffsetT                         num_items;          ///< Total number of input items


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    /**
     * Constructor.  Binds the agent to its shared storage, iterators and operators.
     */
    __device__ __forceinline__
    AgentSelectIf(
        TempStorage                 &temp_storage,      ///< Reference to temp_storage
        InputIteratorT              d_in,               ///< Input data
        FlagsInputIteratorT         d_flags_in,         ///< Input selection flags (if applicable)
        SelectedOutputIteratorT     d_selected_out,     ///< Output data
        SelectOpT                   select_op,          ///< Selection operator
        EqualityOpT                 equality_op,        ///< Equality operator
        OffsetT                     num_items)          ///< Total number of input items
    :
        // Initializers are listed in member declaration order (the order in
        // which they actually run), which avoids -Wreorder diagnostics.
        temp_storage(temp_storage.Alias()),
        d_in(d_in),
        d_selected_out(d_selected_out),
        d_flags_in(d_flags_in),
        inequality_op(equality_op),
        select_op(select_op),
        num_items(num_items)
    {}


    //---------------------------------------------------------------------
    // Utility methods for initializing the selections
    //---------------------------------------------------------------------

    /**
     * Initialize selections (specialized for selection operator)
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     /*tile_offset*/,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_OP>     /*select_method*/)
    {
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Out-of-bounds items are flagged as selected (the callers discount
            // them from the selection count of the last tile)
            selection_flags[ITEM] = 1;

            // Only in-bounds items are evaluated by the selection functor
            if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items))
                selection_flags[ITEM] = select_op(items[ITEM]);
        }
    }


    /**
     * Initialize selections (specialized for valid flags)
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&/*items*/)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_SELECT_FLAGS>  /*select_method*/)
    {
        // load_flags is unioned with load_items, which the caller has just used
        CTA_SYNC();

        FlagT flags[ITEMS_PER_THREAD];

        if (IS_LAST_TILE)
        {
            // Guarded load: out-of-bounds items get a default flag of 1 (selected)
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1);
        }
        else
        {
            BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags);
        }

        // Convert flag type to selection_flags type
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            selection_flags[ITEM] = flags[ITEM];
        }
    }


    /**
     * Initialize selections (specialized for discontinuity detection)
     */
    template <bool IS_FIRST_TILE, bool IS_LAST_TILE>
    __device__ __forceinline__ void InitializeSelections(
        OffsetT                     tile_offset,
        OffsetT                     num_tile_items,
        OutputT                     (&items)[ITEMS_PER_THREAD],
        OffsetT                     (&selection_flags)[ITEMS_PER_THREAD],
        Int2Type<USE_DISCONTINUITY> /*select_method*/)
    {
        if (IS_FIRST_TILE)
        {
            CTA_SYNC();

            // Set head selection_flags.  First tile sets the first flag for the first item
            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op);
        }
        else
        {
            // Thread 0 fetches the last item of the preceding tile so that the
            // first item of this tile can be compared against it
            OutputT tile_predecessor;
            if (threadIdx.x == 0)
                tile_predecessor = d_in[tile_offset - 1];

            CTA_SYNC();

            BlockDiscontinuityT(temp_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor);
        }

        // Set selection flags for out-of-bounds items
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            // Out-of-bounds items are flagged as selected (discounted later by the caller)
            if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items))
                selection_flags[ITEM] = 1;
        }
    }


    //---------------------------------------------------------------------
    // Scatter utility methods
    //---------------------------------------------------------------------

    /**
     * Scatter flagged items to output offsets (specialized for direct scattering)
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterDirect(
        OutputT     (&items)[ITEMS_PER_THREAD],
        OffsetT     (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT     (&selection_indices)[ITEMS_PER_THREAD],
        OffsetT     num_selections)
    {
        // Scatter flagged items
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (selection_flags[ITEM])
            {
                // In the last tile, indices at or beyond num_selections belong
                // to out-of-bounds items and must not be written
                if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections)
                {
                    d_selected_out[selection_indices[ITEM]] = items[ITEM];
                }
            }
        }
    }


    /**
     * Scatter flagged items to output offsets (specialized for two-phase scattering, rejected items discarded)
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             /*num_tile_items*/,                         ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         /*num_rejected_prefix*/,                    ///< Total number of rejections prior to this tile
        Int2Type<false> /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage used by the caller
        CTA_SYNC();

        // Compact selected items into shared memory
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix;
            if (selection_flags[ITEM])
            {
                temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
            }
        }

        CTA_SYNC();

        // Copy the compacted items to the output, striding the block's threads across them
        for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS)
        {
            d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item];
        }
    }


    /**
     * Scatter flagged items to output offsets (specialized for two-phase scattering, rejected items kept)
     *
     * Selected items are written forward from num_selections_prefix; rejected
     * items are written backward from the end of the output (num_items - 1).
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void ScatterTwoPhase(
        OutputT         (&items)[ITEMS_PER_THREAD],
        OffsetT         (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT         (&selection_indices)[ITEMS_PER_THREAD],
        int             num_tile_items,                             ///< Number of valid items in this tile
        int             num_tile_selections,                        ///< Number of selections in this tile
        OffsetT         num_selections_prefix,                      ///< Total number of selections prior to this tile
        OffsetT         num_rejected_prefix,                        ///< Total number of rejections prior to this tile
        Int2Type<true>  /*is_keep_rejects*/)                        ///< Marker type indicating whether to keep rejected items in the second partition
    {
        // raw_exchange is unioned with the scan storage used by the caller
        CTA_SYNC();

        int tile_num_rejections = num_tile_items - num_tile_selections;

        // Scatter items to shared memory (rejections first)
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx                = (threadIdx.x * ITEMS_PER_THREAD) + ITEM;
            int local_selection_idx     = selection_indices[ITEM] - num_selections_prefix;
            int local_rejection_idx     = item_idx - local_selection_idx;
            int local_scatter_offset    = (selection_flags[ITEM]) ?
                                            tile_num_rejections + local_selection_idx :
                                            local_rejection_idx;

            temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM];
        }

        CTA_SYNC();

        // Gather items from shared memory and scatter to global
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            int item_idx                = (ITEM * BLOCK_THREADS) + threadIdx.x;
            int rejection_idx           = item_idx;
            int selection_idx           = item_idx - tile_num_rejections;
            OffsetT scatter_offset      = (item_idx < tile_num_rejections) ?
                                            num_items - num_rejected_prefix - rejection_idx - 1 :
                                            num_selections_prefix + selection_idx;

            OutputT item = temp_storage.raw_exchange.Alias()[item_idx];

            // Skip the out-of-bounds slots of a partially-full last tile
            if (!IS_LAST_TILE || (item_idx < num_tile_items))
            {
                d_selected_out[scatter_offset] = item;
            }
        }
    }


    /**
     * Scatter flagged items
     */
    template <bool IS_LAST_TILE, bool IS_FIRST_TILE>
    __device__ __forceinline__ void Scatter(
        OutputT     (&items)[ITEMS_PER_THREAD],
        OffsetT     (&selection_flags)[ITEMS_PER_THREAD],
        OffsetT     (&selection_indices)[ITEMS_PER_THREAD],
        int         num_tile_items,                                 ///< Number of valid items in this tile
        int         num_tile_selections,                            ///< Number of selections in this tile
        OffsetT     num_selections_prefix,                          ///< Total number of selections prior to this tile
        OffsetT     num_rejected_prefix,                            ///< Total number of rejections prior to this tile
        OffsetT     num_selections)                                 ///< Total number of selections including this tile
    {
        // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled
        // and the average number of selected items per thread is greater than one
        if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS)))
        {
            ScatterTwoPhase<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_tile_items,
                num_tile_selections,
                num_selections_prefix,
                num_rejected_prefix,
                Int2Type<KEEP_REJECTS>());
        }
        else
        {
            ScatterDirect<IS_LAST_TILE, IS_FIRST_TILE>(
                items,
                selection_flags,
                selection_indices,
                num_selections);
        }
    }


    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------

    /**
     * Process first tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeFirstTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded load for a possibly partially-full last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<true, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags
        OffsetT num_tile_selections;
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections);

        if (threadIdx.x == 0)
        {
            // Update tile status if this is not the last tile
            if (!IS_LAST_TILE)
                tile_state.SetInclusive(0, num_tile_selections);
        }

        // Discount any out-of-bounds selections (they were flagged as selected)
        if (IS_LAST_TILE)
            num_tile_selections -= (TILE_ITEMS - num_tile_items);

        // Scatter flagged items (no selections or rejections precede the first tile)
        Scatter<IS_LAST_TILE, true>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            0,
            0,
            num_tile_selections);

        return num_tile_selections;
    }


    /**
     * Process subsequent tile of input (dynamic chained scan).  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeSubsequentTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OutputT     items[ITEMS_PER_THREAD];
        OffsetT     selection_flags[ITEMS_PER_THREAD];
        OffsetT     selection_indices[ITEMS_PER_THREAD];

        // Load items (guarded load for a possibly partially-full last tile)
        if (IS_LAST_TILE)
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items);
        else
            BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items);

        // Initialize selection_flags
        InitializeSelections<false, IS_LAST_TILE>(
            tile_offset,
            num_tile_items,
            items,
            selection_flags,
            Int2Type<SELECT_METHOD>());

        CTA_SYNC();

        // Exclusive scan of selection_flags, seeded with the prefix from preceding tiles
        TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx);
        BlockScanT(temp_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op);

        OffsetT num_tile_selections     = prefix_op.GetBlockAggregate();
        OffsetT num_selections          = prefix_op.GetInclusivePrefix();
        OffsetT num_selections_prefix   = prefix_op.GetExclusivePrefix();

        // Promote to OffsetT before multiplying: (tile_idx * TILE_ITEMS) evaluated
        // in int overflows once the preceding tiles cover 2^31 items or more
        OffsetT num_rejected_prefix     = (static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS)) - num_selections_prefix;

        // Discount any out-of-bounds selections (they were flagged as selected)
        if (IS_LAST_TILE)
        {
            int num_discount    = TILE_ITEMS - num_tile_items;
            num_selections      -= num_discount;
            num_tile_selections -= num_discount;
        }

        // Scatter flagged items
        Scatter<IS_LAST_TILE, false>(
            items,
            selection_flags,
            selection_indices,
            num_tile_items,
            num_tile_selections,
            num_selections_prefix,
            num_rejected_prefix,
            num_selections);

        return num_selections;
    }


    /**
     * Process a tile of input.  Returns the running count of selections (including this tile)
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ OffsetT ConsumeTile(
        int                 num_tile_items,     ///< Number of input items comprising this tile
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        OffsetT num_selections;
        if (tile_idx == 0)
        {
            num_selections = ConsumeFirstTile<IS_LAST_TILE>(num_tile_items, tile_offset, tile_state);
        }
        else
        {
            num_selections = ConsumeSubsequentTile<IS_LAST_TILE>(num_tile_items, tile_idx, tile_offset, tile_state);
        }

        return num_selections;
    }


    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    template <typename NumSelectedIteratorT>        ///< Output iterator type for recording the number of items selected
    __device__ __forceinline__ void ConsumeRange(
        int                     num_tiles,          ///< Total number of input tiles
        ScanTileStateT&         tile_state,         ///< Global tile state descriptor
        NumSelectedIteratorT    d_num_selected_out) ///< Output total number of items selected
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx    = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index

        // Global offset for the current tile.  Promote to OffsetT before
        // multiplying: (tile_idx * TILE_ITEMS) evaluated in int overflows for
        // offsets of 2^31 or more even when OffsetT is a 64-bit type.
        OffsetT tile_offset = static_cast<OffsetT>(tile_idx) * static_cast<OffsetT>(TILE_ITEMS);

        if (tile_idx < num_tiles - 1)
        {
            // Not the last tile (full)
            ConsumeTile<false>(TILE_ITEMS, tile_idx, tile_offset, tile_state);
        }
        else
        {
            // The last tile (possibly partially-full)
            OffsetT num_remaining   = num_items - tile_offset;
            OffsetT num_selections  = ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state);

            if (threadIdx.x == 0)
            {
                // Output the total number of items selected
                *d_num_selected_out = num_selections;
            }
        }
    }

};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)