/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentReduceByKey implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key.
32 */ 33 34 #pragma once 35 36 #include <iterator> 37 38 #include "single_pass_scan_operators.cuh" 39 #include "../block/block_load.cuh" 40 #include "../block/block_store.cuh" 41 #include "../block/block_scan.cuh" 42 #include "../block/block_discontinuity.cuh" 43 #include "../iterator/cache_modified_input_iterator.cuh" 44 #include "../iterator/constant_input_iterator.cuh" 45 #include "../util_namespace.cuh" 46 47 /// Optional outer namespace(s) 48 CUB_NS_PREFIX 49 50 /// CUB namespace 51 namespace cub { 52 53 54 /****************************************************************************** 55 * Tuning policy types 56 ******************************************************************************/ 57 58 /** 59 * Parameterizable tuning policy type for AgentReduceByKey 60 */ 61 template < 62 int _BLOCK_THREADS, ///< Threads per thread block 63 int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) 64 BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use 65 CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements 66 BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use 67 struct AgentReduceByKeyPolicy 68 { 69 enum 70 { 71 BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block 72 ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) 73 }; 74 75 static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use 76 static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements 77 static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use 78 }; 79 80 81 /****************************************************************************** 82 * Thread block abstractions 83 ******************************************************************************/ 84 85 /** 86 * \brief AgentReduceByKey implements a stateful abstraction of CUDA thread blocks 
for participating in device-wide reduce-value-by-key 87 */ 88 template < 89 typename AgentReduceByKeyPolicyT, ///< Parameterized AgentReduceByKeyPolicy tuning policy type 90 typename KeysInputIteratorT, ///< Random-access input iterator type for keys 91 typename UniqueOutputIteratorT, ///< Random-access output iterator type for keys 92 typename ValuesInputIteratorT, ///< Random-access input iterator type for values 93 typename AggregatesOutputIteratorT, ///< Random-access output iterator type for values 94 typename NumRunsOutputIteratorT, ///< Output iterator type for recording number of items selected 95 typename EqualityOpT, ///< KeyT equality operator type 96 typename ReductionOpT, ///< ValueT reduction operator type 97 typename OffsetT> ///< Signed integer type for global offsets 98 struct AgentReduceByKey 99 { 100 //--------------------------------------------------------------------- 101 // Types and constants 102 //--------------------------------------------------------------------- 103 104 // The input keys type 105 typedef typename std::iterator_traits<KeysInputIteratorT>::value_type KeyInputT; 106 107 // The output keys type 108 typedef typename If<(Equals<typename std::iterator_traits<UniqueOutputIteratorT>::value_type, void>::VALUE), // KeyOutputT = (if output iterator's value type is void) ? 109 typename std::iterator_traits<KeysInputIteratorT>::value_type, // ... then the input iterator's value type, 110 typename std::iterator_traits<UniqueOutputIteratorT>::value_type>::Type KeyOutputT; // ... else the output iterator's value type 111 112 // The input values type 113 typedef typename std::iterator_traits<ValuesInputIteratorT>::value_type ValueInputT; 114 115 // The output values type 116 typedef typename If<(Equals<typename std::iterator_traits<AggregatesOutputIteratorT>::value_type, void>::VALUE), // ValueOutputT = (if output iterator's value type is void) ? 117 typename std::iterator_traits<ValuesInputIteratorT>::value_type, // ... 
then the input iterator's value type, 118 typename std::iterator_traits<AggregatesOutputIteratorT>::value_type>::Type ValueOutputT; // ... else the output iterator's value type 119 120 // Tuple type for scanning (pairs accumulated segment-value with segment-index) 121 typedef KeyValuePair<OffsetT, ValueOutputT> OffsetValuePairT; 122 123 // Tuple type for pairing keys and values 124 typedef KeyValuePair<KeyOutputT, ValueOutputT> KeyValuePairT; 125 126 // Tile status descriptor interface type 127 typedef ReduceByKeyScanTileState<ValueOutputT, OffsetT> ScanTileStateT; 128 129 // Guarded inequality functor 130 template <typename _EqualityOpT> 131 struct GuardedInequalityWrapper 132 { 133 _EqualityOpT op; ///< Wrapped equality operator 134 int num_remaining; ///< Items remaining 135 136 /// Constructor 137 __host__ __device__ __forceinline__ GuardedInequalityWrappercub::AgentReduceByKey::GuardedInequalityWrapper138 GuardedInequalityWrapper(_EqualityOpT op, int num_remaining) : op(op), num_remaining(num_remaining) {} 139 140 /// Boolean inequality operator, returns <tt>(a != b)</tt> 141 template <typename T> operator ()cub::AgentReduceByKey::GuardedInequalityWrapper142 __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b, int idx) const 143 { 144 if (idx < num_remaining) 145 return !op(a, b); // In bounds 146 147 // Return true if first out-of-bounds item, false otherwise 148 return (idx == num_remaining); 149 } 150 }; 151 152 153 // Constants 154 enum 155 { 156 BLOCK_THREADS = AgentReduceByKeyPolicyT::BLOCK_THREADS, 157 ITEMS_PER_THREAD = AgentReduceByKeyPolicyT::ITEMS_PER_THREAD, 158 TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, 159 TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), 160 161 // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type) 162 HAS_IDENTITY_ZERO = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueOutputT>::PRIMITIVE), 163 }; 164 165 // Cache-modified Input 
iterator wrapper type (for applying cache modifier) for keys 166 typedef typename If<IsPointer<KeysInputIteratorT>::VALUE, 167 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, KeyInputT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator 168 KeysInputIteratorT>::Type // Directly use the supplied input iterator type 169 WrappedKeysInputIteratorT; 170 171 // Cache-modified Input iterator wrapper type (for applying cache modifier) for values 172 typedef typename If<IsPointer<ValuesInputIteratorT>::VALUE, 173 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator 174 ValuesInputIteratorT>::Type // Directly use the supplied input iterator type 175 WrappedValuesInputIteratorT; 176 177 // Cache-modified Input iterator wrapper type (for applying cache modifier) for fixup values 178 typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE, 179 CacheModifiedInputIterator<AgentReduceByKeyPolicyT::LOAD_MODIFIER, ValueInputT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator 180 AggregatesOutputIteratorT>::Type // Directly use the supplied input iterator type 181 WrappedFixupInputIteratorT; 182 183 // Reduce-value-by-segment scan operator 184 typedef ReduceBySegmentOp<ReductionOpT> ReduceBySegmentOpT; 185 186 // Parameterized BlockLoad type for keys 187 typedef BlockLoad< 188 KeyOutputT, 189 BLOCK_THREADS, 190 ITEMS_PER_THREAD, 191 AgentReduceByKeyPolicyT::LOAD_ALGORITHM> 192 BlockLoadKeysT; 193 194 // Parameterized BlockLoad type for values 195 typedef BlockLoad< 196 ValueOutputT, 197 BLOCK_THREADS, 198 ITEMS_PER_THREAD, 199 AgentReduceByKeyPolicyT::LOAD_ALGORITHM> 200 BlockLoadValuesT; 201 202 // Parameterized BlockDiscontinuity type for keys 203 typedef BlockDiscontinuity< 204 KeyOutputT, 205 BLOCK_THREADS> 206 BlockDiscontinuityKeys; 207 208 // Parameterized BlockScan type 209 typedef 
BlockScan< 210 OffsetValuePairT, 211 BLOCK_THREADS, 212 AgentReduceByKeyPolicyT::SCAN_ALGORITHM> 213 BlockScanT; 214 215 // Callback type for obtaining tile prefix during block scan 216 typedef TilePrefixCallbackOp< 217 OffsetValuePairT, 218 ReduceBySegmentOpT, 219 ScanTileStateT> 220 TilePrefixCallbackOpT; 221 222 // Key and value exchange types 223 typedef KeyOutputT KeyExchangeT[TILE_ITEMS + 1]; 224 typedef ValueOutputT ValueExchangeT[TILE_ITEMS + 1]; 225 226 // Shared memory type for this thread block 227 union _TempStorage 228 { 229 struct 230 { 231 typename BlockScanT::TempStorage scan; // Smem needed for tile scanning 232 typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback 233 typename BlockDiscontinuityKeys::TempStorage discontinuity; // Smem needed for discontinuity detection 234 }; 235 236 // Smem needed for loading keys 237 typename BlockLoadKeysT::TempStorage load_keys; 238 239 // Smem needed for loading values 240 typename BlockLoadValuesT::TempStorage load_values; 241 242 // Smem needed for compacting key value pairs(allows non POD items in this union) 243 Uninitialized<KeyValuePairT[TILE_ITEMS + 1]> raw_exchange; 244 }; 245 246 // Alias wrapper allowing storage to be unioned 247 struct TempStorage : Uninitialized<_TempStorage> {}; 248 249 250 //--------------------------------------------------------------------- 251 // Per-thread fields 252 //--------------------------------------------------------------------- 253 254 _TempStorage& temp_storage; ///< Reference to temp_storage 255 WrappedKeysInputIteratorT d_keys_in; ///< Input keys 256 UniqueOutputIteratorT d_unique_out; ///< Unique output keys 257 WrappedValuesInputIteratorT d_values_in; ///< Input values 258 AggregatesOutputIteratorT d_aggregates_out; ///< Output value aggregates 259 NumRunsOutputIteratorT d_num_runs_out; ///< Output pointer for total number of segments identified 260 EqualityOpT equality_op; ///< KeyT equality operator 261 
ReductionOpT reduction_op; ///< Reduction operator 262 ReduceBySegmentOpT scan_op; ///< Reduce-by-segment scan operator 263 264 265 //--------------------------------------------------------------------- 266 // Constructor 267 //--------------------------------------------------------------------- 268 269 // Constructor 270 __device__ __forceinline__ AgentReduceByKeycub::AgentReduceByKey271 AgentReduceByKey( 272 TempStorage& temp_storage, ///< Reference to temp_storage 273 KeysInputIteratorT d_keys_in, ///< Input keys 274 UniqueOutputIteratorT d_unique_out, ///< Unique output keys 275 ValuesInputIteratorT d_values_in, ///< Input values 276 AggregatesOutputIteratorT d_aggregates_out, ///< Output value aggregates 277 NumRunsOutputIteratorT d_num_runs_out, ///< Output pointer for total number of segments identified 278 EqualityOpT equality_op, ///< KeyT equality operator 279 ReductionOpT reduction_op) ///< ValueT reduction operator 280 : 281 temp_storage(temp_storage.Alias()), 282 d_keys_in(d_keys_in), 283 d_unique_out(d_unique_out), 284 d_values_in(d_values_in), 285 d_aggregates_out(d_aggregates_out), 286 d_num_runs_out(d_num_runs_out), 287 equality_op(equality_op), 288 reduction_op(reduction_op), 289 scan_op(reduction_op) 290 {} 291 292 293 //--------------------------------------------------------------------- 294 // Scatter utility methods 295 //--------------------------------------------------------------------- 296 297 /** 298 * Directly scatter flagged items to output offsets 299 */ ScatterDirectcub::AgentReduceByKey300 __device__ __forceinline__ void ScatterDirect( 301 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], 302 OffsetT (&segment_flags)[ITEMS_PER_THREAD], 303 OffsetT (&segment_indices)[ITEMS_PER_THREAD]) 304 { 305 // Scatter flagged keys and values 306 #pragma unroll 307 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 308 { 309 if (segment_flags[ITEM]) 310 { 311 d_unique_out[segment_indices[ITEM]] = scatter_items[ITEM].key; 312 
d_aggregates_out[segment_indices[ITEM]] = scatter_items[ITEM].value; 313 } 314 } 315 } 316 317 318 /** 319 * 2-phase scatter flagged items to output offsets 320 * 321 * The exclusive scan causes each head flag to be paired with the previous 322 * value aggregate: the scatter offsets must be decremented for value aggregates 323 */ ScatterTwoPhasecub::AgentReduceByKey324 __device__ __forceinline__ void ScatterTwoPhase( 325 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], 326 OffsetT (&segment_flags)[ITEMS_PER_THREAD], 327 OffsetT (&segment_indices)[ITEMS_PER_THREAD], 328 OffsetT num_tile_segments, 329 OffsetT num_tile_segments_prefix) 330 { 331 CTA_SYNC(); 332 333 // Compact and scatter pairs 334 #pragma unroll 335 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 336 { 337 if (segment_flags[ITEM]) 338 { 339 temp_storage.raw_exchange.Alias()[segment_indices[ITEM] - num_tile_segments_prefix] = scatter_items[ITEM]; 340 } 341 } 342 343 CTA_SYNC(); 344 345 for (int item = threadIdx.x; item < num_tile_segments; item += BLOCK_THREADS) 346 { 347 KeyValuePairT pair = temp_storage.raw_exchange.Alias()[item]; 348 d_unique_out[num_tile_segments_prefix + item] = pair.key; 349 d_aggregates_out[num_tile_segments_prefix + item] = pair.value; 350 } 351 } 352 353 354 /** 355 * Scatter flagged items 356 */ Scattercub::AgentReduceByKey357 __device__ __forceinline__ void Scatter( 358 KeyValuePairT (&scatter_items)[ITEMS_PER_THREAD], 359 OffsetT (&segment_flags)[ITEMS_PER_THREAD], 360 OffsetT (&segment_indices)[ITEMS_PER_THREAD], 361 OffsetT num_tile_segments, 362 OffsetT num_tile_segments_prefix) 363 { 364 // Do a one-phase scatter if (a) two-phase is disabled or (b) the average number of selected items per thread is less than one 365 if (TWO_PHASE_SCATTER && (num_tile_segments > BLOCK_THREADS)) 366 { 367 ScatterTwoPhase( 368 scatter_items, 369 segment_flags, 370 segment_indices, 371 num_tile_segments, 372 num_tile_segments_prefix); 373 } 374 else 375 { 376 ScatterDirect( 377 
scatter_items, 378 segment_flags, 379 segment_indices); 380 } 381 } 382 383 384 //--------------------------------------------------------------------- 385 // Cooperatively scan a device-wide sequence of tiles with other CTAs 386 //--------------------------------------------------------------------- 387 388 /** 389 * Process a tile of input (dynamic chained scan) 390 */ 391 template <bool IS_LAST_TILE> ///< Whether the current tile is the last tile ConsumeTilecub::AgentReduceByKey392 __device__ __forceinline__ void ConsumeTile( 393 OffsetT num_remaining, ///< Number of global input items remaining (including this tile) 394 int tile_idx, ///< Tile index 395 OffsetT tile_offset, ///< Tile offset 396 ScanTileStateT& tile_state) ///< Global tile state descriptor 397 { 398 KeyOutputT keys[ITEMS_PER_THREAD]; // Tile keys 399 KeyOutputT prev_keys[ITEMS_PER_THREAD]; // Tile keys shuffled up 400 ValueOutputT values[ITEMS_PER_THREAD]; // Tile values 401 OffsetT head_flags[ITEMS_PER_THREAD]; // Segment head flags 402 OffsetT segment_indices[ITEMS_PER_THREAD]; // Segment indices 403 OffsetValuePairT scan_items[ITEMS_PER_THREAD]; // Zipped values and segment flags|indices 404 KeyValuePairT scatter_items[ITEMS_PER_THREAD]; // Zipped key value pairs for scattering 405 406 // Load keys 407 if (IS_LAST_TILE) 408 BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys, num_remaining); 409 else 410 BlockLoadKeysT(temp_storage.load_keys).Load(d_keys_in + tile_offset, keys); 411 412 // Load tile predecessor key in first thread 413 KeyOutputT tile_predecessor; 414 if (threadIdx.x == 0) 415 { 416 tile_predecessor = (tile_idx == 0) ? 
417 keys[0] : // First tile gets repeat of first item (thus first item will not be flagged as a head) 418 d_keys_in[tile_offset - 1]; // Subsequent tiles get last key from previous tile 419 } 420 421 CTA_SYNC(); 422 423 // Load values 424 if (IS_LAST_TILE) 425 BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values, num_remaining); 426 else 427 BlockLoadValuesT(temp_storage.load_values).Load(d_values_in + tile_offset, values); 428 429 CTA_SYNC(); 430 431 // Initialize head-flags and shuffle up the previous keys 432 if (IS_LAST_TILE) 433 { 434 // Use custom flag operator to additionally flag the first out-of-bounds item 435 GuardedInequalityWrapper<EqualityOpT> flag_op(equality_op, num_remaining); 436 BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( 437 head_flags, keys, prev_keys, flag_op, tile_predecessor); 438 } 439 else 440 { 441 InequalityWrapper<EqualityOpT> flag_op(equality_op); 442 BlockDiscontinuityKeys(temp_storage.discontinuity).FlagHeads( 443 head_flags, keys, prev_keys, flag_op, tile_predecessor); 444 } 445 446 // Zip values and head flags 447 #pragma unroll 448 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 449 { 450 scan_items[ITEM].value = values[ITEM]; 451 scan_items[ITEM].key = head_flags[ITEM]; 452 } 453 454 // Perform exclusive tile scan 455 OffsetValuePairT block_aggregate; // Inclusive block-wide scan aggregate 456 OffsetT num_segments_prefix; // Number of segments prior to this tile 457 OffsetValuePairT total_aggregate; // The tile prefix folded with block_aggregate 458 if (tile_idx == 0) 459 { 460 // Scan first tile 461 BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, block_aggregate); 462 num_segments_prefix = 0; 463 total_aggregate = block_aggregate; 464 465 // Update tile status if there are successor tiles 466 if ((!IS_LAST_TILE) && (threadIdx.x == 0)) 467 tile_state.SetInclusive(0, block_aggregate); 468 } 469 else 470 { 471 // Scan non-first tile 472 
TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx); 473 BlockScanT(temp_storage.scan).ExclusiveScan(scan_items, scan_items, scan_op, prefix_op); 474 475 block_aggregate = prefix_op.GetBlockAggregate(); 476 num_segments_prefix = prefix_op.GetExclusivePrefix().key; 477 total_aggregate = prefix_op.GetInclusivePrefix(); 478 } 479 480 // Rezip scatter items and segment indices 481 #pragma unroll 482 for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) 483 { 484 scatter_items[ITEM].key = prev_keys[ITEM]; 485 scatter_items[ITEM].value = scan_items[ITEM].value; 486 segment_indices[ITEM] = scan_items[ITEM].key; 487 } 488 489 // At this point, each flagged segment head has: 490 // - The key for the previous segment 491 // - The reduced value from the previous segment 492 // - The segment index for the reduced value 493 494 // Scatter flagged keys and values 495 OffsetT num_tile_segments = block_aggregate.key; 496 Scatter(scatter_items, head_flags, segment_indices, num_tile_segments, num_segments_prefix); 497 498 // Last thread in last tile will output final count (and last pair, if necessary) 499 if ((IS_LAST_TILE) && (threadIdx.x == BLOCK_THREADS - 1)) 500 { 501 OffsetT num_segments = num_segments_prefix + num_tile_segments; 502 503 // If the last tile is a whole tile, output the final_value 504 if (num_remaining == TILE_ITEMS) 505 { 506 d_unique_out[num_segments] = keys[ITEMS_PER_THREAD - 1]; 507 d_aggregates_out[num_segments] = total_aggregate.value; 508 num_segments++; 509 } 510 511 // Output the total number of items selected 512 *d_num_runs_out = num_segments; 513 } 514 } 515 516 517 /** 518 * Scan tiles of items as part of a dynamic chained scan 519 */ ConsumeRangecub::AgentReduceByKey520 __device__ __forceinline__ void ConsumeRange( 521 int num_items, ///< Total number of input items 522 ScanTileStateT& tile_state, ///< Global tile state descriptor 523 int start_tile) ///< The starting tile for the current grid 524 { 525 // Blocks are 
launched in increasing order, so just assign one tile per block 526 int tile_idx = start_tile + blockIdx.x; // Current tile index 527 OffsetT tile_offset = OffsetT(TILE_ITEMS) * tile_idx; // Global offset for the current tile 528 OffsetT num_remaining = num_items - tile_offset; // Remaining items (including this tile) 529 530 if (num_remaining > TILE_ITEMS) 531 { 532 // Not last tile 533 ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state); 534 } 535 else if (num_remaining > 0) 536 { 537 // Last tile 538 ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state); 539 } 540 } 541 542 }; 543 544 545 } // CUB namespace 546 CUB_NS_POSTFIX // Optional outer namespace(s) 547 548