/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key.
 */
#pragma once

#include <iterator>

#include "single_pass_scan_operators.cuh"
#include "../block/block_load.cuh"
#include "../block/block_store.cuh"
#include "../block/block_scan.cuh"
#include "../block/block_discontinuity.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
#include "../iterator/constant_input_iterator.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/******************************************************************************
 * Tuning policy types
 ******************************************************************************/

/**
 * Parameterizable tuning policy type for AgentSegmentFixup
 */
template <
    int                         _BLOCK_THREADS,         ///< Threads per thread block
    int                         _ITEMS_PER_THREAD,      ///< Items per thread (per tile of input)
    BlockLoadAlgorithm          _LOAD_ALGORITHM,        ///< The BlockLoad algorithm to use
    CacheLoadModifier           _LOAD_MODIFIER,         ///< Cache load modifier for reading input elements
    BlockScanAlgorithm          _SCAN_ALGORITHM>        ///< The BlockScan algorithm to use
struct AgentSegmentFixupPolicy
{
    enum
    {
        BLOCK_THREADS       = _BLOCK_THREADS,           ///< Threads per thread block
        ITEMS_PER_THREAD    = _ITEMS_PER_THREAD,        ///< Items per thread (per tile of input)
    };

    static const BlockLoadAlgorithm     LOAD_ALGORITHM  = _LOAD_ALGORITHM;      ///< The BlockLoad algorithm to use
    static const CacheLoadModifier      LOAD_MODIFIER   = _LOAD_MODIFIER;       ///< Cache load modifier for reading input elements
    static const BlockScanAlgorithm     SCAN_ALGORITHM  = _SCAN_ALGORITHM;      ///< The BlockScan algorithm to use
};
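
// Example (illustrative sketch only): a concrete policy instantiation in the
// style of CUB's reduce-by-key tunings.  The specific constants below are
// assumptions chosen for the example, not the values CUB's dispatch selects.
//
//     typedef AgentSegmentFixupPolicy<
//             128,                        // BLOCK_THREADS
//             3,                          // ITEMS_PER_THREAD
//             BLOCK_LOAD_DIRECT,          // LOAD_ALGORITHM
//             LOAD_LDG,                   // LOAD_MODIFIER
//             BLOCK_SCAN_WARP_SCANS>      // SCAN_ALGORITHM
//         ExampleSegmentFixupPolicyT;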

/******************************************************************************
 * Thread block abstractions
 ******************************************************************************/

/**
 * \brief AgentSegmentFixup implements a stateful abstraction of CUDA thread blocks for participating in device-wide reduce-value-by-key
 */
template <
    typename    AgentSegmentFixupPolicyT,       ///< Parameterized AgentSegmentFixupPolicy tuning policy type
    typename    PairsInputIteratorT,            ///< Random-access input iterator type for keys
    typename    AggregatesOutputIteratorT,      ///< Random-access output iterator type for values
    typename    EqualityOpT,                    ///< KeyT equality operator type
    typename    ReductionOpT,                   ///< ValueT reduction operator type
    typename    OffsetT>                        ///< Signed integer type for global offsets
struct AgentSegmentFixup
{
    //---------------------------------------------------------------------
    // Types and constants
    //---------------------------------------------------------------------

    // Data type of key-value input iterator
    typedef typename std::iterator_traits<PairsInputIteratorT>::value_type KeyValuePairT;

    // Value type
    typedef typename KeyValuePairT::Value ValueT;

    // Tile status descriptor interface type
    typedef ReduceByKeyScanTileState<ValueT, OffsetT> ScanTileStateT;

    // Constants
    enum
    {
        BLOCK_THREADS       = AgentSegmentFixupPolicyT::BLOCK_THREADS,
        ITEMS_PER_THREAD    = AgentSegmentFixupPolicyT::ITEMS_PER_THREAD,
        TILE_ITEMS          = BLOCK_THREADS * ITEMS_PER_THREAD,

        // Whether or not to perform fixup using RLE + global atomics
        USE_ATOMIC_FIXUP    = (CUB_PTX_ARCH >= 350) &&
                              (Equals<ValueT, float>::VALUE ||
                               Equals<ValueT, int>::VALUE ||
                               Equals<ValueT, unsigned int>::VALUE ||
                               Equals<ValueT, unsigned long long>::VALUE),

        // Whether or not the scan operation has a zero-valued identity value (true if we're performing addition on a primitive type)
        HAS_IDENTITY_ZERO   = (Equals<ReductionOpT, cub::Sum>::VALUE) && (Traits<ValueT>::PRIMITIVE),
    };

    // Cache-modified input iterator wrapper type (for applying cache modifier) for keys
    typedef typename If<IsPointer<PairsInputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, KeyValuePairT, OffsetT>,    // Wrap the native input pointer with CacheModifiedInputIterator
            PairsInputIteratorT>::Type                                                                      // Directly use the supplied input iterator type
        WrappedPairsInputIteratorT;

    // Cache-modified input iterator wrapper type (for applying cache modifier) for fixup values
    typedef typename If<IsPointer<AggregatesOutputIteratorT>::VALUE,
            CacheModifiedInputIterator<AgentSegmentFixupPolicyT::LOAD_MODIFIER, ValueT, OffsetT>,           // Wrap the native output pointer with CacheModifiedInputIterator
            AggregatesOutputIteratorT>::Type                                                                // Directly use the supplied output iterator type
        WrappedFixupInputIteratorT;

    // Reduce-value-by-segment scan operator
    typedef ReduceByKeyOp<cub::Sum> ReduceBySegmentOpT;

    // Parameterized BlockLoad type for pairs
    typedef BlockLoad<
            KeyValuePairT,
            BLOCK_THREADS,
            ITEMS_PER_THREAD,
            AgentSegmentFixupPolicyT::LOAD_ALGORITHM>
        BlockLoadPairs;

    // Parameterized BlockScan type
    typedef BlockScan<
            KeyValuePairT,
            BLOCK_THREADS,
            AgentSegmentFixupPolicyT::SCAN_ALGORITHM>
        BlockScanT;

    // Callback type for obtaining tile prefix during block scan
    typedef TilePrefixCallbackOp<
            KeyValuePairT,
            ReduceBySegmentOpT,
            ScanTileStateT>
        TilePrefixCallbackOpT;

    // Shared memory type for this thread block
    union _TempStorage
    {
        struct
        {
            typename BlockScanT::TempStorage                scan;           // Smem needed for tile scanning
            typename TilePrefixCallbackOpT::TempStorage     prefix;         // Smem needed for cooperative prefix callback
        };

        // Smem needed for loading keys
        typename BlockLoadPairs::TempStorage load_pairs;
    };

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    //---------------------------------------------------------------------
    // Per-thread fields
    //---------------------------------------------------------------------

    _TempStorage&                   temp_storage;       ///< Reference to temp_storage
    WrappedPairsInputIteratorT      d_pairs_in;         ///< Input keys
    AggregatesOutputIteratorT       d_aggregates_out;   ///< Output value aggregates
    WrappedFixupInputIteratorT      d_fixup_in;         ///< Fixup input values
    InequalityWrapper<EqualityOpT>  inequality_op;      ///< KeyT inequality operator
    ReductionOpT                    reduction_op;       ///< Reduction operator
    ReduceBySegmentOpT              scan_op;            ///< Reduce-by-segment scan operator


    //---------------------------------------------------------------------
    // Constructor
    //---------------------------------------------------------------------

    // Constructor
    __device__ __forceinline__
    AgentSegmentFixup(
        TempStorage&                temp_storage,       ///< Reference to temp_storage
        PairsInputIteratorT         d_pairs_in,         ///< Input keys
        AggregatesOutputIteratorT   d_aggregates_out,   ///< Output value aggregates
        EqualityOpT                 equality_op,        ///< KeyT equality operator
        ReductionOpT                reduction_op)       ///< ValueT reduction operator
    :
        temp_storage(temp_storage.Alias()),
        d_pairs_in(d_pairs_in),
        d_aggregates_out(d_aggregates_out),
        d_fixup_in(d_aggregates_out),
        inequality_op(equality_op),
        reduction_op(reduction_op),
        scan_op(reduction_op)
    {}
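
    // Example (illustrative sketch only): constructing the agent inside a
    // kernel.  ExampleSegmentFixupPolicyT and the d_* names are hypothetical
    // caller-supplied inputs, not part of this file.
    //
    //     typedef cub::KeyValuePair<int, float> PairT;
    //     typedef AgentSegmentFixup<
    //             ExampleSegmentFixupPolicyT,
    //             PairT*,                     // pairs input iterator
    //             float*,                     // aggregates output iterator
    //             cub::Equality,              // key equality operator
    //             cub::Sum,                   // value reduction operator
    //             int>                        // signed offset type
    //         AgentSegmentFixupT;
    //
    //     __shared__ typename AgentSegmentFixupT::TempStorage temp_storage;
    //
    //     AgentSegmentFixupT agent(
    //         temp_storage, d_pairs_in, d_aggregates_out,
    //         cub::Equality(), cub::Sum());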

    //---------------------------------------------------------------------
    // Cooperatively scan a device-wide sequence of tiles with other CTAs
    //---------------------------------------------------------------------


    /**
     * Process input tile.  Specialized for atomic-fixup
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        Int2Type<true>      use_atomic_fixup)   ///< Marker whether to use atomicAdd (instead of reduce-by-key)
    {
        KeyValuePairT pairs[ITEMS_PER_THREAD];

        // Load pairs
        KeyValuePairT oob_pair;
        oob_pair.key = -1;

        if (IS_LAST_TILE)
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
        else
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);

        // RLE
        #pragma unroll
        for (int ITEM = 1; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            ValueT* d_scatter = d_aggregates_out + pairs[ITEM - 1].key;
            if (pairs[ITEM].key != pairs[ITEM - 1].key)
                atomicAdd(d_scatter, pairs[ITEM - 1].value);
            else
                pairs[ITEM].value = reduction_op(pairs[ITEM - 1].value, pairs[ITEM].value);
        }

        // Flush last item if valid
        ValueT* d_scatter = d_aggregates_out + pairs[ITEMS_PER_THREAD - 1].key;
        if ((!IS_LAST_TILE) || (pairs[ITEMS_PER_THREAD - 1].key >= 0))
            atomicAdd(d_scatter, pairs[ITEMS_PER_THREAD - 1].value);
    }
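
    // Worked example (illustrative) for the atomic-fixup path above: with
    // ITEMS_PER_THREAD == 3, a thread holding pairs {(2,1), (2,4), (5,3)}
    // combines the two key-2 values in registers (no atomic for the first
    // pair), issues atomicAdd(d_aggregates_out + 2, 5) when the key changes,
    // and flushes the trailing (5,3) with a final atomicAdd.  Out-of-bounds
    // pairs in the last tile carry key == -1 and are never flushed.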

    /**
     * Process input tile.  Specialized for reduce-by-key fixup
     */
    template <bool IS_LAST_TILE>
    __device__ __forceinline__ void ConsumeTile(
        OffsetT             num_remaining,      ///< Number of global input items remaining (including this tile)
        int                 tile_idx,           ///< Tile index
        OffsetT             tile_offset,        ///< Tile offset
        ScanTileStateT&     tile_state,         ///< Global tile state descriptor
        Int2Type<false>     use_atomic_fixup)   ///< Marker whether to use atomicAdd (instead of reduce-by-key)
    {
        KeyValuePairT pairs[ITEMS_PER_THREAD];
        KeyValuePairT scatter_pairs[ITEMS_PER_THREAD];

        // Load pairs
        KeyValuePairT oob_pair;
        oob_pair.key = -1;

        if (IS_LAST_TILE)
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs, num_remaining, oob_pair);
        else
            BlockLoadPairs(temp_storage.load_pairs).Load(d_pairs_in + tile_offset, pairs);

        CTA_SYNC();

        KeyValuePairT tile_aggregate;
        if (tile_idx == 0)
        {
            // Exclusive scan of values and segment_flags
            BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, tile_aggregate);

            // Update tile status if this is not the last tile
            if (threadIdx.x == 0)
            {
                // Set first segment id to not trigger a flush (invalid from exclusive scan)
                scatter_pairs[0].key = pairs[0].key;

                if (!IS_LAST_TILE)
                    tile_state.SetInclusive(0, tile_aggregate);
            }
        }
        else
        {
            // Exclusive scan of values and segment_flags
            TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.prefix, scan_op, tile_idx);
            BlockScanT(temp_storage.scan).ExclusiveScan(pairs, scatter_pairs, scan_op, prefix_op);
            tile_aggregate = prefix_op.GetBlockAggregate();
        }

        // Scatter updated values
        #pragma unroll
        for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
        {
            if (scatter_pairs[ITEM].key != pairs[ITEM].key)
            {
                // Update the value at the key location
                ValueT value    = d_fixup_in[scatter_pairs[ITEM].key];
                value           = reduction_op(value, scatter_pairs[ITEM].value);

                d_aggregates_out[scatter_pairs[ITEM].key] = value;
            }
        }

        // Finalize the last item
        if (IS_LAST_TILE)
        {
            // Last thread will output final count and last item, if necessary
            if (threadIdx.x == BLOCK_THREADS - 1)
            {
                // If the last tile is a whole tile, the inclusive prefix contains accumulated value reduction for the last segment
                if (num_remaining == TILE_ITEMS)
                {
                    // Update the value at the key location
                    OffsetT last_key = pairs[ITEMS_PER_THREAD - 1].key;
                    d_aggregates_out[last_key] = reduction_op(tile_aggregate.value, d_fixup_in[last_key]);
                }
            }
        }
    }
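
    // Worked example (illustrative) for the reduce-by-key path above: the
    // exclusive scan gives each item the running aggregate of the segment
    // that precedes it.  Where scatter_pairs[ITEM].key differs from
    // pairs[ITEM].key, item ITEM begins a new segment, so the scanned pair
    // holds the finished aggregate of the segment that just ended and is
    // folded into d_aggregates_out at that key.  E.g., for tile keys
    // {7, 7, 9, ...}, the item holding the first 9 sees scatter key 7 and
    // fixes up d_aggregates_out[7] with the accumulated key-7 value.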

    /**
     * Scan tiles of items as part of a dynamic chained scan
     */
    __device__ __forceinline__ void ConsumeRange(
        int                 num_items,          ///< Total number of input items
        int                 num_tiles,          ///< Total number of input tiles
        ScanTileStateT&     tile_state)         ///< Global tile state descriptor
    {
        // Blocks are launched in increasing order, so just assign one tile per block
        int     tile_idx        = (blockIdx.x * gridDim.y) + blockIdx.y;    // Current tile index
        OffsetT tile_offset     = tile_idx * TILE_ITEMS;                    // Global offset for the current tile
        OffsetT num_remaining   = num_items - tile_offset;                  // Remaining items (including this tile)

        if (num_remaining > TILE_ITEMS)
        {
            // Not the last tile (full)
            ConsumeTile<false>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
        else if (num_remaining > 0)
        {
            // The last tile (possibly partially-full)
            ConsumeTile<true>(num_remaining, tile_idx, tile_offset, tile_state, Int2Type<USE_ATOMIC_FIXUP>());
        }
    }

};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)
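
// Example (illustrative sketch, not CUB's actual dispatch kernel): how a
// fixup kernel might drive AgentSegmentFixup.  PolicyT and the kernel name
// are hypothetical; cub::Equality and cub::Sum are the stock CUB functors.
//
//     template <typename PolicyT, typename PairsItT, typename AggItT,
//               typename ScanTileStateT, typename OffsetT>
//     __launch_bounds__ (int(PolicyT::BLOCK_THREADS))
//     __global__ void ExampleSegmentFixupKernel(
//         PairsItT        d_pairs_in,         // input key-value pairs
//         AggItT          d_aggregates_out,   // output aggregates to fix up
//         OffsetT         num_items,          // total number of input items
//         int             num_tiles,          // total number of input tiles
//         ScanTileStateT  tile_state)         // global tile status descriptor
//     {
//         typedef cub::AgentSegmentFixup<
//             PolicyT, PairsItT, AggItT, cub::Equality, cub::Sum, OffsetT> AgentT;
//
//         // Shared memory for the agent
//         __shared__ typename AgentT::TempStorage temp_storage;
//
//         // Process this block's tile(s)
//         AgentT(temp_storage, d_pairs_in, d_aggregates_out, cub::Equality(), cub::Sum())
//             .ConsumeRange(num_items, num_tiles, tile_state);
//     }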