/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * Simple binary operator functor types
 */

/******************************************************************************
 * Simple functor operators
 ******************************************************************************/

#pragma once

#include "../util_macro.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/**
 * \addtogroup UtilModule
 * @{
 */

/**
 * \brief Default equality functor
 */
struct Equality
{
    /// Boolean equality operator, returns <tt>(a == b)</tt>
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a == b;
    }
};


/**
 * \brief Default inequality functor
 */
struct Inequality
{
    /// Boolean inequality operator, returns <tt>(a != b)</tt>
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a != b;
    }
};


/**
 * \brief Inequality functor (wraps equality functor)
 */
template <typename EqualityOp>
struct InequalityWrapper
{
    /// Wrapped equality operator
    EqualityOp op;

    /// Constructor
    __host__ __device__ __forceinline__
    InequalityWrapper(EqualityOp op) : op(op) {}

    /// Boolean inequality operator, returns <tt>(a != b)</tt> by negating the wrapped equality operator
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b)
    {
        return !op(a, b);
    }
};


/**
 * \brief Default sum functor
 */
struct Sum
{
    /// Binary sum operator, returns <tt>a + b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a + b;
    }
};


/**
 * \brief Default max functor
 */
struct Max
{
    /// Binary max operator, returns <tt>(a > b) ? a : b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MAX(a, b);
    }
};


/**
 * \brief Arg max functor (keeps the value and offset of the first occurrence of the larger item)
 */
struct ArgMax
{
    /// Binary max operator, preferring the item having the smaller offset in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};


/**
 * \brief Default min functor
 */
struct Min
{
    /// Binary min operator, returns <tt>(a < b) ? a : b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MIN(a, b);
    }
};


/**
 * \brief Arg min functor (keeps the value and offset of the first occurrence of the smallest item)
 */
struct ArgMin
{
    /// Binary min operator, preferring the item having the smaller offset in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};


/**
 * \brief Default cast functor
 */
template <typename B>
struct CastOp
{
    /// Cast operator, returns <tt>(B) a</tt>
    template <typename A>
    __host__ __device__ __forceinline__ B operator()(const A &a) const
    {
        return (B) a;
    }
};


/**
 * \brief Binary operator wrapper for switching non-commutative scan arguments
 */
template <typename ScanOp>
class SwizzleScanOp
{
private:

    /// Wrapped scan operator
    ScanOp scan_op;

public:

    /// Constructor
    __host__ __device__ __forceinline__
    SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {}

    /// Switch the scan arguments
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &a, const T &b)
    {
        T _a(a);
        T _b(b);

        return scan_op(_b, _a);
    }
};


/**
 * \brief Reduce-by-segment functor.
 *
 * Given two cub::KeyValuePair inputs \p a and \p b and a
 * binary associative combining operator \p <tt>f(const T &x, const T &y)</tt>,
 * an instance of this functor returns a cub::KeyValuePair whose \p key
 * field is <tt>a.key</tt> + <tt>b.key</tt>, and whose \p value field
 * is either b.value if b.key is non-zero, or f(a.value, b.value) otherwise.
 *
 * ReduceBySegmentOp is an associative, non-commutative binary combining operator
 * for input sequences of cub::KeyValuePair pairings.  Such
 * sequences are typically used to represent a segmented set of values to be reduced
 * and a corresponding set of {0,1}-valued integer "head flags" demarcating the
 * first value of each segment.
 *
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceBySegmentOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceBySegmentOp() {}

    /// Constructor
    __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>       ///< KeyValuePair pairing of T (value) and OffsetT (head flag)
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,         ///< First partial reduction
        const KeyValuePairT &second)        ///< Second partial reduction
    {
        KeyValuePairT retval;
        retval.key = first.key + second.key;
        retval.value = (second.key) ?
                second.value :                          // The second partial reduction spans a segment reset, so its value aggregate becomes the running aggregate
                op(first.value, second.value);          // The second partial reduction does not span a reset, so accumulate both into the running aggregate
        return retval;
    }
};


/**
 * \brief Reduce-value-by-key functor.
 *
 * Returns the \p second partial reduction unchanged when its key differs from
 * that of \p first; otherwise accumulates both values with the wrapped
 * reduction operator (keeping the shared key).
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceByKeyOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceByKeyOp() {}

    /// Constructor
    __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,         ///< First partial reduction
        const KeyValuePairT &second)        ///< Second partial reduction
    {
        KeyValuePairT retval = second;

        if (first.key == second.key)
            retval.value = op(first.value, retval.value);

        return retval;
    }
};


/** @} */       // end group UtilModule


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)