1/* 2 * Copyright 2008-2013 NVIDIA Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <thrust/iterator/iterator_traits.h> 18#include <thrust/detail/temporary_array.h> 19#include <thrust/system/tbb/detail/execution_policy.h> 20#include <thrust/merge.h> 21#include <thrust/binary_search.h> 22#include <thrust/detail/seq.h> 23#include <tbb/parallel_for.h> 24 25namespace thrust 26{ 27namespace system 28{ 29namespace tbb 30{ 31namespace detail 32{ 33namespace merge_detail 34{ 35 36template<typename InputIterator1, 37 typename InputIterator2, 38 typename OutputIterator, 39 typename StrictWeakOrdering> 40struct range 41{ 42 InputIterator1 first1, last1; 43 InputIterator2 first2, last2; 44 OutputIterator result; 45 StrictWeakOrdering comp; 46 size_t grain_size; 47 48 range(InputIterator1 first1, InputIterator1 last1, 49 InputIterator2 first2, InputIterator2 last2, 50 OutputIterator result, 51 StrictWeakOrdering comp, 52 size_t grain_size = 1024) 53 : first1(first1), last1(last1), 54 first2(first2), last2(last2), 55 result(result), comp(comp), grain_size(grain_size) 56 {} 57 58 range(range& r, ::tbb::split) 59 : first1(r.first1), last1(r.last1), 60 first2(r.first2), last2(r.last2), 61 result(r.result), comp(r.comp), grain_size(r.grain_size) 62 { 63 // we can assume n1 and n2 are not both 0 64 size_t n1 = thrust::distance(first1, last1); 65 size_t n2 = thrust::distance(first2, last2); 66 67 InputIterator1 mid1 = first1; 68 InputIterator2 mid2 = first2; 69 70 if (n1 > n2) 71 { 72 mid1 += n1 / 2; 73 mid2 = thrust::lower_bound(thrust::seq, first2, last2, raw_reference_cast(*mid1), comp); 74 } 75 else 76 { 77 mid2 += n2 / 2; 78 mid1 = thrust::upper_bound(thrust::seq, first1, last1, raw_reference_cast(*mid2), comp); 79 } 80 81 // set first range to [first1, mid1), [first2, mid2), result 82 r.last1 = mid1; 83 r.last2 = mid2; 84 85 // set second range to [mid1, last1), [mid2, last2), result + (mid1 - first1) + (mid2 - first2) 86 first1 = mid1; 87 first2 = mid2; 88 result += thrust::distance(r.first1, mid1) + thrust::distance(r.first2, mid2); 89 } 90 91 bool empty(void) const 92 { 93 return (first1 == last1) && (first2 == last2); 94 } 95 96 bool is_divisible(void) const 97 { 98 return static_cast<size_t>(thrust::distance(first1, last1) + thrust::distance(first2, last2)) > grain_size; 99 } 100}; 101 102struct body 103{ 104 template <typename Range> 105 void operator()(Range& r) const 106 { 107 thrust::merge(thrust::seq, 108 r.first1, r.last1, 109 r.first2, r.last2, 110 r.result, 111 r.comp); 112 } 113}; 114 115} // end namespace merge_detail 116 117namespace merge_by_key_detail 118{ 119 120template<typename InputIterator1, 121 typename InputIterator2, 122 typename InputIterator3, 123 typename InputIterator4, 124 typename OutputIterator1, 125 typename OutputIterator2, 126 typename StrictWeakOrdering> 127struct range 128{ 129 InputIterator1 keys_first1, keys_last1; 130 InputIterator2 keys_first2, keys_last2; 131 InputIterator3 values_first1; 132 InputIterator4 values_first2; 133 OutputIterator1 keys_result; 134 OutputIterator2 values_result; 135 StrictWeakOrdering comp; 136 size_t grain_size; 137 138 range(InputIterator1 keys_first1, InputIterator1 keys_last1, 139 InputIterator2 keys_first2, InputIterator2 keys_last2, 140 InputIterator3 values_first1, 141 InputIterator4 values_first2, 142 OutputIterator1 keys_result, 143 OutputIterator2 values_result, 144 StrictWeakOrdering comp, 145 size_t grain_size = 1024) 146 : keys_first1(keys_first1), keys_last1(keys_last1), 147 keys_first2(keys_first2), keys_last2(keys_last2), 148 values_first1(values_first1), 149 values_first2(values_first2), 150 keys_result(keys_result), values_result(values_result), 151 comp(comp), grain_size(grain_size) 152 {} 153 154 range(range& r, ::tbb::split) 155 : keys_first1(r.keys_first1), keys_last1(r.keys_last1), 156 keys_first2(r.keys_first2), keys_last2(r.keys_last2), 157 values_first1(r.values_first1), 158 values_first2(r.values_first2), 159 keys_result(r.keys_result), values_result(r.values_result), 160 comp(r.comp), grain_size(r.grain_size) 161 { 162 // we can assume n1 and n2 are not both 0 163 size_t n1 = thrust::distance(keys_first1, keys_last1); 164 size_t n2 = thrust::distance(keys_first2, keys_last2); 165 166 InputIterator1 mid1 = keys_first1; 167 InputIterator2 mid2 = keys_first2; 168 169 if (n1 > n2) 170 { 171 mid1 += n1 / 2; 172 mid2 = thrust::lower_bound(thrust::seq, keys_first2, keys_last2, raw_reference_cast(*mid1), comp); 173 } 174 else 175 { 176 mid2 += n2 / 2; 177 mid1 = thrust::upper_bound(thrust::seq, keys_first1, keys_last1, raw_reference_cast(*mid2), comp); 178 } 179 180 // set first range to [keys_first1, mid1), [keys_first2, mid2), keys_result, values_result 181 r.keys_last1 = mid1; 182 r.keys_last2 = mid2; 183 184 // set second range to [mid1, keys_last1), [mid2, keys_last2), keys_result + (mid1 - keys_first1) + (mid2 - keys_first2), values_result + (mid1 - keys_first1) + (mid2 - keys_first2) 185 keys_first1 = mid1; 186 keys_first2 = mid2; 187 values_first1 += thrust::distance(r.keys_first1, mid1); 188 values_first2 += thrust::distance(r.keys_first2, mid2); 189 keys_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2); 190 values_result += thrust::distance(r.keys_first1, mid1) + thrust::distance(r.keys_first2, mid2); 191 } 192 193 bool empty(void) const 194 { 195 return (keys_first1 == keys_last1) && (keys_first2 == keys_last2); 196 } 197 198 bool is_divisible(void) const 199 { 200 return static_cast<size_t>(thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)) > grain_size; 201 } 202}; 203 204struct body 205{ 206 template <typename Range> 207 void operator()(Range& r) const 208 { 209 thrust::merge_by_key(thrust::seq, 210 r.keys_first1, r.keys_last1, 211 r.keys_first2, r.keys_last2, 212 r.values_first1, 213 r.values_first2, 214 r.keys_result, 215 r.values_result, 216 r.comp); 217 } 218}; 219 220} // end namespace merge_by_key_detail 221 222 223template<typename DerivedPolicy, 224 typename InputIterator1, 225 typename InputIterator2, 226 typename OutputIterator, 227 typename StrictWeakOrdering> 228OutputIterator merge(execution_policy<DerivedPolicy> &exec, 229 InputIterator1 first1, 230 InputIterator1 last1, 231 InputIterator2 first2, 232 InputIterator2 last2, 233 OutputIterator result, 234 StrictWeakOrdering comp) 235{ 236 typedef typename merge_detail::range<InputIterator1,InputIterator2,OutputIterator,StrictWeakOrdering> Range; 237 typedef merge_detail::body Body; 238 Range range(first1, last1, first2, last2, result, comp); 239 Body body; 240 241 ::tbb::parallel_for(range, body); 242 243 thrust::advance(result, thrust::distance(first1, last1) + thrust::distance(first2, last2)); 244 245 return result; 246} // end merge() 247 248template <typename DerivedPolicy, 249 typename InputIterator1, 250 typename InputIterator2, 251 typename InputIterator3, 252 typename InputIterator4, 253 typename OutputIterator1, 254 typename OutputIterator2, 255 typename StrictWeakOrdering> 256thrust::pair<OutputIterator1,OutputIterator2> 257 merge_by_key(execution_policy<DerivedPolicy> &exec, 258 InputIterator1 keys_first1, 259 InputIterator1 keys_last1, 260 InputIterator2 keys_first2, 261 InputIterator2 keys_last2, 262 InputIterator3 values_first3, 263 InputIterator4 values_first4, 264 OutputIterator1 keys_result, 265 OutputIterator2 values_result, 266 StrictWeakOrdering comp) 267{ 268 typedef typename merge_by_key_detail::range<InputIterator1,InputIterator2,InputIterator3,InputIterator4,OutputIterator1,OutputIterator2,StrictWeakOrdering> Range; 269 typedef merge_by_key_detail::body Body; 270 271 Range range(keys_first1, keys_last1, keys_first2, keys_last2, values_first3, values_first4, keys_result, values_result, comp); 272 Body body; 273 274 ::tbb::parallel_for(range, body); 275 276 thrust::advance(keys_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)); 277 thrust::advance(values_result, thrust::distance(keys_first1, keys_last1) + thrust::distance(keys_first2, keys_last2)); 278 279 return thrust::make_pair(keys_result,values_result); 280} 281 282} // end namespace detail 283} // end namespace tbb 284} // end namespace system 285} // end namespace thrust 286 287