/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * Callback operator types for supplying BlockScan prefixes
 */

#pragma once

#include <iterator>

#include "../thread/thread_load.cuh"
#include "../thread/thread_store.cuh"
#include "../warp/warp_reduce.cuh"
#include "../util_arch.cuh"
#include "../util_device.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Prefix functor type for maintaining a running prefix while scanning a
 * region independent of other thread blocks
 ******************************************************************************/

/**
 * Stateful callback operator type for supplying BlockScan prefixes.
 * Maintains a running prefix that can be applied to consecutive
 * BlockScan operations.
 */
template <
    typename T,                 ///< BlockScan value type
    typename ScanOpT>           ///< Wrapped scan operator type
struct BlockScanRunningPrefixOp
{
    ScanOpT     op;                 ///< Wrapped scan operator
    T           running_total;      ///< Running block-wide prefix

    /// Constructor
    __device__ __forceinline__ BlockScanRunningPrefixOp(ScanOpT op)
    :
        op(op)
    {}

    /// Constructor
    __device__ __forceinline__ BlockScanRunningPrefixOp(
        T starting_prefix,
        ScanOpT op)
    :
        op(op),
        running_total(starting_prefix)
    {}

    /**
     * Prefix callback operator.  Returns the block-wide running_total in thread-0.
     */
    __device__ __forceinline__ T operator()(
        const T &block_aggregate)              ///< The aggregate sum of the BlockScan inputs
    {
        T retval = running_total;
        running_total = op(running_total, block_aggregate);
        return retval;
    }
};
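
// A minimal usage sketch (editorial illustration, not part of the library):
// one thread block scans its input in consecutive tiles, carrying the running
// total between tiles via BlockScanRunningPrefixOp.  The kernel name, tile
// width (128), and data layout below are hypothetical.
//
//     __global__ void ExampleKernel(int *d_data, int num_tiles)
//     {
//         typedef cub::BlockScan<int, 128> BlockScanT;
//         __shared__ typename BlockScanT::TempStorage temp_storage;
//
//         // Seed the running total with the identity for addition
//         cub::BlockScanRunningPrefixOp<int, cub::Sum> prefix_op(0, cub::Sum());
//
//         for (int tile = 0; tile < num_tiles; ++tile)
//         {
//             int item = d_data[(tile * 128) + threadIdx.x];
//             BlockScanT(temp_storage).ExclusiveSum(item, item, prefix_op);
//             __syncthreads();    // temp_storage is reused across tiles
//             d_data[(tile * 128) + threadIdx.x] = item;
//         }
//     }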


/******************************************************************************
 * Generic tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Enumerations of tile status
 */
enum ScanTileStatus
{
    SCAN_TILE_OOB,          // Out-of-bounds (e.g., padding)
    SCAN_TILE_INVALID = 99, // Not yet processed
    SCAN_TILE_PARTIAL,      // Tile aggregate is available
    SCAN_TILE_INCLUSIVE,    // Inclusive tile prefix is available
};
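
// Lifecycle: a tile's descriptor begins as SCAN_TILE_INVALID, becomes
// SCAN_TILE_PARTIAL once the tile's local aggregate has been published,
// and becomes SCAN_TILE_INCLUSIVE once its full inclusive prefix is known.
// SCAN_TILE_OOB marks the padding entries that precede the first tile.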


/**
 * Tile status interface.
 */
template <
    typename    T,
    bool        SINGLE_WORD = Traits<T>::PRIMITIVE>
struct ScanTileState;


/**
 * Tile status interface specialized for scan status and value types
 * that can be combined into one machine word that can be
 * read/written coherently in a single access.
 */
template <typename T>
struct ScanTileState<T, true>
{
    // Status word type
    typedef typename If<(sizeof(T) == 8),
        long long,
        typename If<(sizeof(T) == 4),
            int,
            typename If<(sizeof(T) == 2),
                short,
                char>::Type>::Type>::Type StatusWord;


    // Unit word type
    typedef typename If<(sizeof(T) == 8),
        longlong2,
        typename If<(sizeof(T) == 4),
            int2,
            typename If<(sizeof(T) == 2),
                int,
                uchar2>::Type>::Type>::Type TxnWord;


    // Device word type
    struct TileDescriptor
    {
        StatusWord  status;
        T           value;
    };


    // Constants
    enum
    {
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };
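
    // Note: TILE_STATUS_PADDING provides one warp's worth of leading
    // descriptors (initialized to SCAN_TILE_OOB) so that the warp-wide
    // look-back window can inspect "predecessors" of the earliest tiles
    // without bounds checking.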

    // Device storage
    TxnWord *d_tile_descriptors;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

        TxnWord val = TxnWord();
        TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value = tile_inclusive;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status = SCAN_TILE_PARTIAL;
        tile_descriptor.value = tile_partial;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status = tile_descriptor.status;
        value = tile_descriptor.value;
    }

};
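
// A minimal host-side sketch (editorial illustration; the kernel name and
// block size are hypothetical) of the two-pass allocation idiom used by this
// interface: query AllocationSize() for the required bytes, allocate, Init()
// the interface over that storage, then launch a kernel that calls
// InitializeStatus() over all tiles before the scan kernel runs.
//
//     cub::ScanTileState<int> tile_state;
//     size_t temp_storage_bytes = 0;
//     cub::ScanTileState<int>::AllocationSize(num_tiles, temp_storage_bytes);
//
//     void *d_temp_storage = NULL;
//     cudaMalloc(&d_temp_storage, temp_storage_bytes);
//     tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
//
//     // InitStatusKernel (hypothetical) simply calls
//     // tile_state.InitializeStatus(num_tiles) in every thread
//     int init_threads = 128;
//     int init_blocks  = (num_tiles + init_threads - 1) / init_threads;
//     InitStatusKernel<<<init_blocks, init_threads>>>(tile_state, num_tiles);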


/**
 * Tile status interface specialized for scan status and value types that
 * cannot be combined into one machine word.
 */
template <typename T>
struct ScanTileState<T, false>
{
    // Status word type
    typedef char StatusWord;

    // Constants
    enum
    {
        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };

    // Device storage
    StatusWord  *d_tile_status;
    T           *d_tile_partial;
    T           *d_tile_inclusive;

    /// Constructor
    __host__ __device__ __forceinline__
    ScanTileState()
    :
        d_tile_status(NULL),
        d_tile_partial(NULL),
        d_tile_inclusive(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     num_tiles,                          ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  temp_storage_bytes)                 ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        cudaError_t error = cudaSuccess;
        do
        {
            void*   allocations[3];
            size_t  allocation_sizes[3];

            allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);           // bytes needed for tile status descriptors
            allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for partials
            allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);     // bytes needed for inclusives

            // Compute allocation pointers into the single storage blob
            if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;

            // Alias the offsets
            d_tile_status       = reinterpret_cast<StatusWord*>(allocations[0]);
            d_tile_partial      = reinterpret_cast<T*>(allocations[1]);
            d_tile_inclusive    = reinterpret_cast<T*>(allocations[2]);
        }
        while (0);

        return error;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        // Specify storage allocation requirements
        size_t  allocation_sizes[3];
        allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);         // bytes needed for tile status descriptors
        allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for partials
        allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);   // bytes needed for inclusives

        // Set the necessary size of the blob
        void* allocations[3];
        return CubDebug(AliasTemporaries(NULL, temp_storage_bytes, allocations, allocation_sizes));
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            d_tile_status[TILE_STATUS_PADDING + tile_idx] = StatusWord(SCAN_TILE_INVALID);
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            d_tile_status[threadIdx.x] = StatusWord(SCAN_TILE_OOB);
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, T tile_inclusive)
    {
        // Update tile inclusive value
        ThreadStore<STORE_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx, tile_inclusive);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_INCLUSIVE));
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, T tile_partial)
    {
        // Update tile partial value
        ThreadStore<STORE_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx, tile_partial);

        // Fence
        __threadfence();

        // Update tile status
        ThreadStore<STORE_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx, StatusWord(SCAN_TILE_PARTIAL));
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int             tile_idx,
        StatusWord      &status,
        T               &value)
    {
        do {
            status = ThreadLoad<LOAD_CG>(d_tile_status + TILE_STATUS_PADDING + tile_idx);

            __threadfence();    // prevent hoisting loads out of the loop, and keep the value loads below from moving above this status load

        } while (status == SCAN_TILE_INVALID);

        if (status == StatusWord(SCAN_TILE_PARTIAL))
            value = ThreadLoad<LOAD_CG>(d_tile_partial + TILE_STATUS_PADDING + tile_idx);
        else
            value = ThreadLoad<LOAD_CG>(d_tile_inclusive + TILE_STATUS_PADDING + tile_idx);
    }
};


/******************************************************************************
 * ReduceByKey tile status interface types for block-cooperative scans
 ******************************************************************************/

/**
 * Tile status interface for reduction by key.
 */
template <
    typename    ValueT,
    typename    KeyT,
    bool        SINGLE_WORD = (Traits<ValueT>::PRIMITIVE) && (sizeof(ValueT) + sizeof(KeyT) < 16)>
struct ReduceByKeyScanTileState;


/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * cannot be combined into one machine word.
 */
template <
    typename    ValueT,
    typename    KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, false> :
    ScanTileState<KeyValuePair<KeyT, ValueT> >
{
    typedef ScanTileState<KeyValuePair<KeyT, ValueT> > SuperClass;

    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState() : SuperClass() {}
};


/**
 * Tile status interface for reduction by key, specialized for scan status and value types that
 * can be combined into one machine word that can be read/written coherently in a single access.
 */
template <
    typename ValueT,
    typename KeyT>
struct ReduceByKeyScanTileState<ValueT, KeyT, true>
{
    typedef KeyValuePair<KeyT, ValueT> KeyValuePairT;

    // Constants
    enum
    {
        PAIR_SIZE           = sizeof(ValueT) + sizeof(KeyT),
        TXN_WORD_SIZE       = 1 << Log2<PAIR_SIZE + 1>::VALUE,
        STATUS_WORD_SIZE    = TXN_WORD_SIZE - PAIR_SIZE,

        TILE_STATUS_PADDING = CUB_PTX_WARP_THREADS,
    };
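
    // Worked example (illustrative): for KeyT = int and ValueT = float,
    // PAIR_SIZE is 8, Log2<9>::VALUE is 4 (CUB's Log2 rounds up), so
    // TXN_WORD_SIZE is 16 and STATUS_WORD_SIZE is 8.  Rounding
    // PAIR_SIZE + 1 up to a power of two guarantees at least one byte is
    // left over for the status word within a single coherently-accessible
    // transaction word.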

    // Status word type
    typedef typename If<(STATUS_WORD_SIZE == 8),
        long long,
        typename If<(STATUS_WORD_SIZE == 4),
            int,
            typename If<(STATUS_WORD_SIZE == 2),
                short,
                char>::Type>::Type>::Type StatusWord;

    // Txn word type
    typedef typename If<(TXN_WORD_SIZE == 16),
        longlong2,
        typename If<(TXN_WORD_SIZE == 8),
            long long,
            int>::Type>::Type TxnWord;

    // Device word type (for when sizeof(ValueT) == sizeof(KeyT))
    struct TileDescriptorBigStatus
    {
        KeyT        key;
        ValueT      value;
        StatusWord  status;
    };

    // Device word type (for when sizeof(ValueT) != sizeof(KeyT))
    struct TileDescriptorLittleStatus
    {
        ValueT      value;
        StatusWord  status;
        KeyT        key;
    };

    // Device word type
    typedef typename If<
            (sizeof(ValueT) == sizeof(KeyT)),
            TileDescriptorBigStatus,
            TileDescriptorLittleStatus>::Type
        TileDescriptor;


    // Device storage
    TxnWord *d_tile_descriptors;


    /// Constructor
    __host__ __device__ __forceinline__
    ReduceByKeyScanTileState()
    :
        d_tile_descriptors(NULL)
    {}


    /// Initializer
    __host__ __device__ __forceinline__
    cudaError_t Init(
        int     /*num_tiles*/,                      ///< [in] Number of tiles
        void    *d_temp_storage,                    ///< [in] %Device-accessible allocation of temporary storage.  When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
        size_t  /*temp_storage_bytes*/)             ///< [in] Size in bytes of \p d_temp_storage allocation
    {
        d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
        return cudaSuccess;
    }


    /**
     * Compute device memory needed for tile status
     */
    __host__ __device__ __forceinline__
    static cudaError_t AllocationSize(
        int     num_tiles,                          ///< [in] Number of tiles
        size_t  &temp_storage_bytes)                ///< [out] Size in bytes of \p d_temp_storage allocation
    {
        temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TileDescriptor);       // bytes needed for tile status descriptors
        return cudaSuccess;
    }


    /**
     * Initialize (from device)
     */
    __device__ __forceinline__ void InitializeStatus(int num_tiles)
    {
        int             tile_idx    = (blockIdx.x * blockDim.x) + threadIdx.x;
        TxnWord         val         = TxnWord();
        TileDescriptor  *descriptor = reinterpret_cast<TileDescriptor*>(&val);

        if (tile_idx < num_tiles)
        {
            // Not-yet-set
            descriptor->status = StatusWord(SCAN_TILE_INVALID);
            d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
        }

        if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
        {
            // Padding
            descriptor->status = StatusWord(SCAN_TILE_OOB);
            d_tile_descriptors[threadIdx.x] = val;
        }
    }


    /**
     * Update the specified tile's inclusive value and corresponding status
     */
    __device__ __forceinline__ void SetInclusive(int tile_idx, KeyValuePairT tile_inclusive)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_INCLUSIVE;
        tile_descriptor.value   = tile_inclusive.value;
        tile_descriptor.key     = tile_inclusive.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }


    /**
     * Update the specified tile's partial value and corresponding status
     */
    __device__ __forceinline__ void SetPartial(int tile_idx, KeyValuePairT tile_partial)
    {
        TileDescriptor tile_descriptor;
        tile_descriptor.status  = SCAN_TILE_PARTIAL;
        tile_descriptor.value   = tile_partial.value;
        tile_descriptor.key     = tile_partial.key;

        TxnWord alias;
        *reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
        ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
    }

    /**
     * Wait for the corresponding tile to become non-invalid
     */
    __device__ __forceinline__ void WaitForValid(
        int                     tile_idx,
        StatusWord              &status,
        KeyValuePairT           &value)
    {
        TileDescriptor tile_descriptor;
        do
        {
            __threadfence_block(); // prevent hoisting loads from loop
            TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
            tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

        } while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

        status      = tile_descriptor.status;
        value.value = tile_descriptor.value;
        value.key   = tile_descriptor.key;
    }

};


/******************************************************************************
 * Prefix call-back operator for coupling local block scan within a
 * block-cooperative scan
 ******************************************************************************/

/**
 * Stateful block-scan prefix functor.  Provides the running prefix for
 * the current tile by using the call-back warp to wait on
 * aggregates/prefixes from predecessor tiles to become available.
 */
template <
    typename    T,
    typename    ScanOpT,
    typename    ScanTileStateT,
    int         PTX_ARCH = CUB_PTX_ARCH>
struct TilePrefixCallbackOp
{
    // Parameterized warp reduce
    typedef WarpReduce<T, CUB_PTX_WARP_THREADS, PTX_ARCH> WarpReduceT;

    // Temporary storage type
    struct _TempStorage
    {
        typename WarpReduceT::TempStorage   warp_reduce;
        T                                   exclusive_prefix;
        T                                   inclusive_prefix;
        T                                   block_aggregate;
    };

    // Alias wrapper allowing temporary storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};

    // Type of status word
    typedef typename ScanTileStateT::StatusWord StatusWord;

    // Fields
    _TempStorage&               temp_storage;       ///< Reference to temporary storage
    ScanTileStateT&             tile_status;        ///< Interface to tile status
    ScanOpT                     scan_op;            ///< Binary scan operator
    int                         tile_idx;           ///< The current tile index
    T                           exclusive_prefix;   ///< Exclusive prefix for the tile
    T                           inclusive_prefix;   ///< Inclusive prefix for the tile

    // Constructor
    __device__ __forceinline__
    TilePrefixCallbackOp(
        ScanTileStateT      &tile_status,
        TempStorage         &temp_storage,
        ScanOpT             scan_op,
        int                 tile_idx)
    :
        temp_storage(temp_storage.Alias()),
        tile_status(tile_status),
        scan_op(scan_op),
        tile_idx(tile_idx) {}


    // Block until all predecessors within the warp-wide window have non-invalid status
    __device__ __forceinline__
    void ProcessWindow(
        int         predecessor_idx,        ///< Preceding tile index to inspect
        StatusWord  &predecessor_status,    ///< [out] Preceding tile status
        T           &window_aggregate)      ///< [out] Relevant partial reduction from this window of preceding tiles
    {
        T value;
        tile_status.WaitForValid(predecessor_idx, predecessor_status, value);

        // Perform a segmented reduction to get the prefix for the current window.
        // Use the swizzled scan operator because we are now scanning *down* towards thread0.

        int tail_flag = (predecessor_status == StatusWord(SCAN_TILE_INCLUSIVE));
        window_aggregate = WarpReduceT(temp_storage.warp_reduce).TailSegmentedReduce(
            value,
            tail_flag,
            SwizzleScanOp<ScanOpT>(scan_op));
    }


    // BlockScan prefix callback functor (called by the first warp)
    __device__ __forceinline__
    T operator()(T block_aggregate)
    {
        // Update our status with our tile-aggregate
        if (threadIdx.x == 0)
        {
            temp_storage.block_aggregate = block_aggregate;
            tile_status.SetPartial(tile_idx, block_aggregate);
        }

        int         predecessor_idx = tile_idx - threadIdx.x - 1;
        StatusWord  predecessor_status;
        T           window_aggregate;

        // Wait for the warp-wide window of predecessor tiles to become valid
        ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);

        // The exclusive tile prefix starts out as the current window aggregate
        exclusive_prefix = window_aggregate;

        // Keep sliding the window back until we come across a tile whose inclusive prefix is known
        while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
        {
            predecessor_idx -= CUB_PTX_WARP_THREADS;

            // Update exclusive tile prefix with the window prefix
            ProcessWindow(predecessor_idx, predecessor_status, window_aggregate);
            exclusive_prefix = scan_op(window_aggregate, exclusive_prefix);
        }

        // Compute the inclusive tile prefix and update the status for this tile
        if (threadIdx.x == 0)
        {
            inclusive_prefix = scan_op(exclusive_prefix, block_aggregate);
            tile_status.SetInclusive(tile_idx, inclusive_prefix);

            temp_storage.exclusive_prefix = exclusive_prefix;
            temp_storage.inclusive_prefix = inclusive_prefix;
        }

        // Return exclusive_prefix
        return exclusive_prefix;
    }

    // Get the exclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetExclusivePrefix()
    {
        return temp_storage.exclusive_prefix;
    }

    // Get the inclusive prefix stored in temporary storage
    __device__ __forceinline__
    T GetInclusivePrefix()
    {
        return temp_storage.inclusive_prefix;
    }

    // Get the block aggregate stored in temporary storage
    __device__ __forceinline__
    T GetBlockAggregate()
    {
        return temp_storage.block_aggregate;
    }

};
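
// A minimal device-side sketch (editorial illustration; the kernel name, tile
// width, and grid mapping are hypothetical) of coupling TilePrefixCallbackOp
// with BlockScan in a single-pass scan: each thread block scans one tile, and
// blocks after the first obtain their running prefix via decoupled look-back.
// Note that the prefix callback runs *during* the block scan, so its
// temporary storage must not be unioned with the scan's.
//
//     template <typename ScanTileStateT>
//     __global__ void ExampleScanKernel(int *d_in, int *d_out, ScanTileStateT tile_state)
//     {
//         typedef cub::BlockScan<int, 128>                                    BlockScanT;
//         typedef cub::TilePrefixCallbackOp<int, cub::Sum, ScanTileStateT>    PrefixOpT;
//
//         __shared__ struct {
//             typename PrefixOpT::TempStorage     prefix;
//             typename BlockScanT::TempStorage    scan;
//         } temp_storage;
//
//         int tile_idx = blockIdx.x;
//         int item     = d_in[(tile_idx * 128) + threadIdx.x];
//
//         if (tile_idx == 0)
//         {
//             // The first tile has no predecessors: scan locally and
//             // publish its inclusive prefix (its aggregate) directly
//             int block_aggregate;
//             BlockScanT(temp_storage.scan).ExclusiveSum(item, item, block_aggregate);
//             if (threadIdx.x == 0)
//                 tile_state.SetInclusive(0, block_aggregate);
//         }
//         else
//         {
//             // Later tiles obtain their running prefix through the callback
//             PrefixOpT prefix_op(tile_state, temp_storage.prefix, cub::Sum(), tile_idx);
//             BlockScanT(temp_storage.scan).ExclusiveSum(item, item, prefix_op);
//         }
//
//         d_out[(tile_idx * 128) + threadIdx.x] = item;
//     }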


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
