/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * Callback operator types for supplying BlockScan prefixes
 */

#pragma once

#include <iterator>

#include "../thread/thread_load.cuh"
#include "../thread/thread_store.cuh"
#include "../warp/warp_reduce.cuh"
#include "../util_arch.cuh"
#include "../util_device.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Prefix functor type for maintaining a running prefix while scanning a
 * region independent of other thread blocks
 ******************************************************************************/

/**
 * Stateful callback operator type for supplying BlockScan prefixes.
 * Maintains a running prefix that can be applied to consecutive
 * BlockScan operations.
 */
template <
    typename T,                     ///< BlockScan value type
    typename ScanOpT>               ///< Wrapped scan operator type
struct BlockScanRunningPrefixOp
{
    ScanOpT     op;                 ///< Wrapped scan operator
    T           running_total;      ///< Running block-wide prefix

    /// Constructor
    __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op)
    :
        op(op)
    {}

    /// Constructor
    __device__ __forceinline__ BlockScanRunningPrefixOp(
        T starting_prefix,
        ScanOpT op)
    :
        op(op),
        running_total(starting_prefix)
    {}

    /**
     * Prefix callback operator.  Returns the block-wide running_total in thread-0.
     */
    __device__ __forceinline__ T operator()(
        const T &block_aggregate)   ///< The aggregate sum of the BlockScan inputs
    {
        T retval = running_total;
        running_total = op(running_total, block_aggregate);
        return retval;
    }
};
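
// An illustrative usage sketch: a single thread block makes consecutive BlockScan
// calls over successive passes of input, carrying the running total between passes
// with a BlockScanRunningPrefixOp.  Identifiers such as ExampleKernel, BLOCK_THREADS,
// and d_data below are hypothetical.
//
//      __global__ void ExampleKernel(int *d_data, int num_passes)
//      {
//          enum { BLOCK_THREADS = 128 };
//          typedef cub::BlockScan<int, BLOCK_THREADS> BlockScanT;
//          __shared__ typename BlockScanT::TempStorage temp_storage;
//
//          // Seed the running prefix with zero for an exclusive prefix sum
//          cub::BlockScanRunningPrefixOp<int, cub::Sum> prefix_op(0, cub::Sum());
//
//          for (int pass = 0; pass < num_passes; ++pass)
//          {
//              int item = d_data[(pass * BLOCK_THREADS) + threadIdx.x];
//              BlockScanT(temp_storage).ExclusiveSum(item, item, prefix_op);
//              __syncthreads();
//              d_data[(pass * BLOCK_THREADS) + threadIdx.x] = item;
//          }
//      }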


/******************************************************************************
 * Generic tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Enumeration of tile status
 */
enum ScanTileStatus
{
    SCAN_TILE_OOB,          // Out-of-bounds (e.g., padding)
    SCAN_TILE_INVALID = 99, // Not yet processed
    SCAN_TILE_PARTIAL,      // Tile aggregate is available
    SCAN_TILE_INCLUSIVE,    // Inclusive tile prefix is available
};


/**
 * Tile status interface.
 */
template <
    typename    T,
    bool        SINGLE_WORD = Traits<T>::PRIMITIVE>
struct ScanTileState;


/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
 */
template <typename T>
struct ScanTileState<T, true>
{
    // Status word type
    typedef typename If<(sizeof(T) == 8),
        long long,
        typename If<(sizeof(T) == 4),
            int,
            typename If<(sizeof(T) == 2),
                short,
                char>::Type>::Type>::Type StatusWord;


    // Unit word type
    typedef typename If<(sizeof(T) == 8),
        longlong2,
        typename If<(sizeof(T) == 4),
            int2,
            typename If<(sizeof(T) == 2),
                int,
                uchar2>::Type>::Type>::Type TxnWord;


    // Device word type
    struct TileDescriptor
    {
        StatusWord  status;
        T           value;
    };


    // Constants
    enum
    {
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };


    // Device storage
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);   // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        TxnWord val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status = tile_descriptor.status;
        value = tile_descriptor.value;
    }

};
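
// An illustrative host-side setup sketch for the tile-status interface: query the
// descriptor footprint with AllocationSize(), carve it from a caller-provided
// temporary-storage blob, Init() the interface, and launch a grid that calls
// InitializeStatus() before the scan kernel runs.  Names such as ExampleInitKernel,
// d_temp_storage, and num_tiles are hypothetical.
//
//      typedef cub::ScanTileState<float> ScanTileStateT;
//
//      // How many temporary-storage bytes do the tile descriptors need?
//      size_t temp_storage_bytes;
//      ScanTileStateT::AllocationSize(num_tiles, temp_storage_bytes);
//
//      // (Caller allocates d_temp_storage of at least temp_storage_bytes)
//      ScanTileStateT tile_state;
//      tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
//
//      // Initialize tile status before launching the scan kernel proper;
//      // ExampleInitKernel simply calls tile_state.InitializeStatus(num_tiles)
//      const int INIT_BLOCK_THREADS = 128;
//      int init_grid_size = (num_tiles + INIT_BLOCK_THREADS - 1) / INIT_BLOCK_THREADS;
//      ExampleInitKernel<<<init_grid_size, INIT_BLOCK_THREADS>>>(tile_state, num_tiles);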


/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word.
 */
template <typename T>
struct ScanTileState<T, false>
{
    // Status word type
    typedef char StatusWord;

    // Constants
    enum
    {
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage
    StatusWord  *d_tile_status;
    T           *d_tile_partial;
    T           *d_tile_inclusive;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_status(NULL),
        d_tile_partial(NULL),
        d_tile_inclusive(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     num_tiles,                          ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  temp_storage_bytes)                 ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        cudaError_t error = cudaSuccess;
        do
        {
            void*   allocations[3];
            size_t  allocation_sizes[3];

            allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
            allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
            allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

            // Compute allocation pointers into the single storage blob
            if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;

            // Alias the offsets
            d_tile_status       = reinterpret_cast<StatusWord*>(allocations[0]);
            d_tile_partial      = reinterpret_cast<T*>(allocations[1]);
            d_tile_inclusive    = reinterpret_cast<T*>(allocations[2]);
        }
        while (0);

        return error;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        // Specify storage allocation requirements
        size_t  allocation_sizes[3];
        allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
        allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
        allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

        // Set the necessary size of the blob
        void* allocations[3];
        return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes));
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        // Update tile inclusive value
        ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        // Update tile partial value
        ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        do {
            status = ThreadLoad<LOAD_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx);

            __threadfence();    // prevent hoisting loads from the loop, or moving the loads below above this one

        } while (status == SCAN_TILE_INVALID);

        if (status == StatusWord(SCAN_TILE_PARTIAL))
            value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
        else
            value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
    }
};


/******************************************************************************
 * ReduceByKey tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Tile status interface for reduction by key.
 */
template <
    typename    ValueT,
    typename    KeyT,
    bool        SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;


/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word.
 */
template <
    typename    ValueT,
    typename    KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, false> :
    ScanTileState<KeyValuePair<KeyT, ValueT> >
{
    typedef ScanTileState<KeyValuePair<KeyT, ValueT> > SuperClass;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState() : SuperClass() {}
};


/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
 */
template <
    typename    ValueT,
    typename    KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
    typedef KeyValuePair<KeyT, ValueT> KeyValuePairT;

    // Constants
    enum
    {
        PAIR_SIZE           = sizeof(ValueT) + sizeof(KeyT),
        TXN_WORD_SIZE       = 1 << Log2<PAIR_SIZE + 1>::VALUE,
        STATUS_WORD_SIZE    = TXN_WORD_SIZE - PAIR_SIZE,

        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };
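
    // Worked example of the sizing arithmetic above (illustrative): for ValueT = float and
    // KeyT = int, PAIR_SIZE = 4 + 4 = 8, so TXN_WORD_SIZE = 1 << Log2<9>::VALUE = 16 (the
    // smallest power of two strictly greater than PAIR_SIZE), and STATUS_WORD_SIZE = 16 - 8 = 8.
    // The status word is then a long long, the transaction word a longlong2, and each tile
    // descriptor is read/written as a single 16-byte access.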

    // Status word type
    typedef typename If<(STATUS_WORD_SIZE == 8),
        long long,
        typename If<(STATUS_WORD_SIZE == 4),
            int,
            typename If<(STATUS_WORD_SIZE == 2),
                short,
                char>::Type>::Type>::Type StatusWord;

    // Txn word type
    typedef typename If<(TXN_WORD_SIZE == 16),
        longlong2,
        typename If<(TXN_WORD_SIZE == 8),
            long long,
            int>::Type>::Type TxnWord;

    // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
    struct TileDescriptorBigStatus
    {
        KeyT        key;
        ValueT      value;
        StatusWord  status;
    };

    // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
    struct TileDescriptorLittleStatus
    {
        ValueT      value;
        StatusWord  status;
        KeyT        key;
    };

    // Device word type
    typedef typename If<
            (sizeof(ValueT) == sizeof(KeyT)),
            TileDescriptorBigStatus,
            TileDescriptorLittleStatus>::Type
        TileDescriptor;


    // Device storage
    TxnWord *d_tile_descriptors;


    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);   // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int             tile_idx    = (blockIdx.x * blockDim.x) + threadIdx.x;
        TxnWord         val         = TxnWord();
        TileDescriptor  *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value   = tile_inclusive.value;
        tile_descriptor.key     = tile_inclusive.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_PARTIAL;
        tile_descriptor.value   = tile_partial.value;
        tile_descriptor.key     = tile_partial.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        KeyValuePairT   &value)
    {
//        TxnWord         alias           = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
//        TileDescriptor  tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
//
//        while (tile_descriptor.status == SCAN_TILE_INVALID)
//        {
//            __threadfence_block(); // prevent hoisting loads from loop
//
//            alias           = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
//            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
//        }
//
//        status      = tile_descriptor.status;
//        value.value = tile_descriptor.value;
//        value.key   = tile_descriptor.key;

        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status      = tile_descriptor.status;
        value.value = tile_descriptor.value;
        value.key   = tile_descriptor.key;
    }

};


/******************************************************************************
 * Prefix call-back operator for coupling local block scan within a
 * block-cooperative scan
 ******************************************************************************/

/**
 * Stateful block-scan prefix functor.  Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available.
 */
template <
    typename    T,
    typename    ScanOpT,
    typename    ScanTileStateT,
    int         PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;

    // Temporary storage type
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage   warp_reduce;
        T                                   exclusive_prefix;
        T                                   inclusive_prefix;
        T                                   block_aggregate;
    };

    // Alias wrapper allowing temporary storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    _TempStorage&               temp_storage;       ///< Reference to a warp-reduction instance
    ScanTileStateT&             tile_status;        ///< Interface to tile status
    ScanOpT                     scan_op;            ///< Binary scan operator
    int                         tile_idx;           ///< The current tile index
    T                           exclusive_prefix;   ///< Exclusive prefix for the tile
    T                           inclusive_prefix;   ///< Inclusive prefix for the tile

    // Constructor
    __device__ __forceinline__
    TilePrefixCallbackOp(
        ScanTileStateT      &tile_status,
        TempStorage         &temp_storage,
        ScanOpT             scan_op,
        int                 tile_idx)
    :
        temp_storage(temp_storage.Alias()),
        tile_status(tile_status),
        scan_op(scan_op),
        tile_idx(tile_idx) {}


    // Block until all predecessors within the warp-wide window have non-invalid status
    __device__ __forceinline__
    void ProcessWindow(
        int         predecessor_idx,        ///< Preceding tile index to inspect
        StatusWord  &predecessor_status,    ///< [out] Preceding tile status
        T           &window_aggregate)      ///< [out] Relevant partial reduction from this window of preceding tiles
    {
        T value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, value);

        // Perform a segmented reduction to get the prefix for the current window.
        // Use the swizzled scan operator because we are now scanning *down* towards thread0.

        int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
        window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(
            value,
            tail_flag,
            SwizzleScanOp<ScanOpT>(scan_op));
    }


    // BlockScan prefix callback functor (called by the first warp)
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {

        // Update our status with our tile-aggregate
        if (threadIdx.x == 0)
        {
            temp_storage.block_aggregate = block_aggregate;
            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        int         predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord  predecessor_status;
        T           window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            temp_storage.exclusive_prefix = exclusive_prefix;
            temp_storage.inclusive_prefix = inclusive_prefix;
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    // Get the exclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetInclusivePrefix()
    {
        return temp_storage.inclusive_prefix;
    }

    // Get the block aggregate stored in temporary storage
    __device__ __forceinline__
    T GetBlockAggregate()
    {
        return temp_storage.block_aggregate;
    }

};
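
// An illustrative device-side consumption sketch for the decoupled look-back pattern:
// tile 0 publishes its inclusive aggregate directly, while every subsequent tile scans
// with a TilePrefixCallbackOp so that its first warp can resolve the exclusive prefix
// from predecessor tiles.  Identifiers such as BlockScanT, items, temp_storage, and
// tile_state are hypothetical stand-ins for the enclosing agent's types and data.
//
//      if (tile_idx == 0)
//      {
//          // First tile: scan locally and publish the inclusive aggregate
//          T block_aggregate;
//          BlockScanT(temp_storage.scan).ExclusiveScan(items, items, scan_op, block_aggregate);
//          if (threadIdx.x == 0)
//              tile_state.SetInclusive(0, block_aggregate);
//      }
//      else
//      {
//          // Subsequent tiles: let the callback obtain the running prefix
//          TilePrefixCallbackOp<T, ScanOpT, ScanTileStateT> prefix_op(
//              tile_state, temp_storage.prefix, scan_op, tile_idx);
//          BlockScanT(temp_storage.scan).ExclusiveScan(items, items, scan_op, prefix_op);
//      }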


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
