/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * Simple binary operator functor types
 */

/******************************************************************************
 * Simple functor operators
 ******************************************************************************/

#pragma once

#include "../util_macro.cuh"
#include "../util_type.cuh"
#include "../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {


/**
 * \addtogroup UtilModule
 * @{
 */

/**
 * \brief Default equality functor
 */
struct Equality
{
    /// Boolean equality operator, returns <tt>(a == b)</tt>
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a == b;
    }
};


/**
 * \brief Default inequality functor
 */
struct Inequality
{
    /// Boolean inequality operator, returns <tt>(a != b)</tt>
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a != b;
    }
};


/**
 * \brief Inequality functor (wraps equality functor)
 */
template <typename EqualityOp>
struct InequalityWrapper
{
    /// Wrapped equality operator
    EqualityOp op;

    /// Constructor
    __host__ __device__ __forceinline__
    InequalityWrapper(EqualityOp op) : op(op) {}

    /// Boolean inequality operator, returns <tt>(a != b)</tt>
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b)
    {
        return !op(a, b);
    }
};
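
// Usage sketch (illustrative only, not part of the library interface):
// InequalityWrapper negates whatever equality functor it wraps, so a custom
// comparator can be reused for "not equal" tests.
//
//     cub::Equality eq_op;
//     cub::InequalityWrapper<cub::Equality> neq_op(eq_op);
//     bool different = neq_op(1, 2);      // true:  !(1 == 2)
//     bool same      = neq_op(3, 3);      // false: !(3 == 3)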


/**
 * \brief Default sum functor
 */
struct Sum
{
    /// Binary sum operator, returns <tt>a + b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a + b;
    }
};
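
// Usage sketch (illustrative only; assumes the CUB 1.8 cub::DeviceReduce::Reduce
// interface and that d_temp_storage, temp_storage_bytes, d_in, d_out, and
// num_items are set up by the caller): cub::Sum is typically passed as the
// reduction operator to device-wide or block-wide primitives.
//
//     cub::Sum sum_op;
//     float    init = 0.0f;
//     cub::DeviceReduce::Reduce(d_temp_storage, temp_storage_bytes,
//                               d_in, d_out, num_items, sum_op, init);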


/**
 * \brief Default max functor
 */
struct Max
{
    /// Binary max operator, returns <tt>(a > b) ? a : b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MAX(a, b);
    }
};


/**
 * \brief Arg max functor (keeps the value and offset of the first occurrence of the larger item)
 */
struct ArgMax
{
    /// Binary max operator, preferring the item having the smaller offset in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
// Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
//        return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};
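
// Example (illustrative only): combining two partial argmax results. On a
// value tie the pair with the smaller offset (key) wins, which is what makes
// the reduction keep the first occurrence of the maximum.
//
//     cub::ArgMax argmax_op;
//     cub::KeyValuePair<int, float> a, b;
//     a.key = 0; a.value = 2.5f;          // offset 0, value 2.5
//     b.key = 3; b.value = 2.5f;          // offset 3, value 2.5
//     cub::KeyValuePair<int, float> best = argmax_op(a, b);
//     // Values tie, so the smaller offset is kept: best.key == 0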


/**
 * \brief Default min functor
 */
struct Min
{
    /// Binary min operator, returns <tt>(a < b) ? a : b</tt>
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MIN(a, b);
    }
};


/**
 * \brief Arg min functor (keeps the value and offset of the first occurrence of the smallest item)
 */
struct ArgMin
{
    /// Binary min operator, preferring the item having the smaller offset in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
// Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
//        return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};


/**
 * \brief Default cast functor
 */
template <typename B>
struct CastOp
{
    /// Cast operator, returns <tt>(B) a</tt>
    template <typename A>
    __host__ __device__ __forceinline__ B operator()(const A &a) const
    {
        return (B) a;
    }
};


/**
 * \brief Binary operator wrapper for switching non-commutative scan arguments
 */
template <typename ScanOp>
class SwizzleScanOp
{
private:

    /// Wrapped scan operator
    ScanOp scan_op;

public:

    /// Constructor
    __host__ __device__ __forceinline__
    SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {}

    /// Switch the scan arguments
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &a, const T &b)
    {
        T _a(a);
        T _b(b);

        return scan_op(_b, _a);
    }
};
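
// Example (illustrative only, using a hypothetical non-commutative Minus
// functor): SwizzleScanOp applies the wrapped operator with its arguments
// reversed.
//
//     struct Minus
//     {
//         __host__ __device__ __forceinline__
//         int operator()(const int &a, const int &b) const { return a - b; }
//     };
//
//     Minus minus_op;
//     cub::SwizzleScanOp<Minus> swizzled_op(minus_op);
//     int x = minus_op(10, 4);      //  6: computes 10 - 4
//     int y = swizzled_op(10, 4);   // -6: computes  4 - 10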


/**
 * \brief Reduce-by-segment functor.
 *
 * Given two cub::KeyValuePair inputs \p a and \p b and a
 * binary associative combining operator <tt>f(const T &x, const T &y)</tt>,
 * an instance of this functor returns a cub::KeyValuePair whose \p key
 * field is <tt>a.key</tt> + <tt>b.key</tt>, and whose \p value field
 * is either <tt>b.value</tt> if <tt>b.key</tt> is non-zero, or <tt>f(a.value, b.value)</tt> otherwise.
 *
 * ReduceBySegmentOp is an associative, non-commutative binary combining operator
 * for input sequences of cub::KeyValuePair pairings.  Such
 * sequences are typically used to represent a segmented set of values to be reduced
 * and a corresponding set of {0,1}-valued integer "head flags" demarcating the
 * first value of each segment.
 *
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceBySegmentOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceBySegmentOp() {}

    /// Constructor
    __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>       ///< KeyValuePair pairing of T (value) and OffsetT (head flag)
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,         ///< First partial reduction
        const KeyValuePairT &second)        ///< Second partial reduction
    {
        KeyValuePairT retval;
        retval.key = first.key + second.key;
        retval.value = (second.key) ?
                second.value :                          // The second partial reduction spans a segment reset, so its value aggregate becomes the running aggregate
                op(first.value, second.value);          // The second partial reduction does not span a reset, so accumulate both values into the running aggregate
        return retval;
    }
};
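
// Example (illustrative only, with cub::Sum as the wrapped operator): keys act
// as {0,1} head flags (summed into a running flag count), and a non-zero key
// on the second operand resets the running value aggregate.
//
//     cub::Sum sum_op;
//     cub::ReduceBySegmentOp<cub::Sum> segment_op(sum_op);
//     cub::KeyValuePair<int, int> a, b, c;
//     a.key = 0; a.value = 5;
//     b.key = 0; b.value = 7;
//     c.key = 1; c.value = 3;
//     cub::KeyValuePair<int, int> r1 = segment_op(a, b);  // {0, 12}: same segment, values accumulate
//     cub::KeyValuePair<int, int> r2 = segment_op(a, c);  // {1, 3}:  segment reset, second value wins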


/**
 * \brief Reduce-by-key functor.
 *
 * Combines the \p value fields of two partial reductions when their \p key
 * fields match; otherwise the second partial reduction passes through unchanged.
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceByKeyOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceByKeyOp() {}

    /// Constructor
    __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,       ///< First partial reduction
        const KeyValuePairT &second)      ///< Second partial reduction
    {
        KeyValuePairT retval = second;

        if (first.key == second.key)
            retval.value = op(first.value, retval.value);

        return retval;
    }
};
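
// Example (illustrative only, with cub::Sum as the wrapped operator): value
// aggregates are combined only when both partial reductions carry the same key;
// otherwise the second operand passes through unchanged.
//
//     cub::Sum sum_op;
//     cub::ReduceByKeyOp<cub::Sum> key_op(sum_op);
//     cub::KeyValuePair<int, int> a, b, c;
//     a.key = 7; a.value = 4;
//     b.key = 7; b.value = 6;
//     c.key = 8; c.value = 9;
//     cub::KeyValuePair<int, int> r1 = key_op(a, b);  // same key:      {7, 10}
//     cub::KeyValuePair<int, int> r2 = key_op(a, c);  // different key: {8, 9}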


/** @} */       // end group UtilModule


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)