1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_WIH_MERGE_PATH_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_WIH_MERGE_PATH_HPP
13 
14 #include <iterator>
15 
16 #include <boost/compute/algorithm/detail/merge_path.hpp>
17 #include <boost/compute/algorithm/fill_n.hpp>
18 #include <boost/compute/container/vector.hpp>
19 #include <boost/compute/detail/iterator_range_size.hpp>
20 #include <boost/compute/detail/meta_kernel.hpp>
21 #include <boost/compute/system.hpp>
22 
23 namespace boost {
24 namespace compute {
25 namespace detail {
26 
27 ///
28 /// \brief Serial merge kernel class
29 ///
30 /// Subclass of meta_kernel to perform serial merge after tiling
31 ///
32 class serial_merge_kernel : meta_kernel
33 {
34 public:
35     unsigned int tile_size;
36 
serial_merge_kernel()37     serial_merge_kernel() : meta_kernel("merge")
38     {
39         tile_size = 4;
40     }
41 
42     template<class InputIterator1, class InputIterator2,
43              class InputIterator3, class InputIterator4,
44              class OutputIterator, class Compare>
set_range(InputIterator1 first1,InputIterator2 first2,InputIterator3 tile_first1,InputIterator3 tile_last1,InputIterator4 tile_first2,OutputIterator result,Compare comp)45     void set_range(InputIterator1 first1,
46                    InputIterator2 first2,
47                    InputIterator3 tile_first1,
48                    InputIterator3 tile_last1,
49                    InputIterator4 tile_first2,
50                    OutputIterator result,
51                    Compare comp)
52     {
53         m_count = iterator_range_size(tile_first1, tile_last1) - 1;
54 
55         *this <<
56         "uint i = get_global_id(0);\n" <<
57         "uint start1 = " << tile_first1[expr<uint_>("i")] << ";\n" <<
58         "uint end1 = " << tile_first1[expr<uint_>("i+1")] << ";\n" <<
59         "uint start2 = " << tile_first2[expr<uint_>("i")] << ";\n" <<
60         "uint end2 = " << tile_first2[expr<uint_>("i+1")] << ";\n" <<
61         "uint index = i*" << tile_size << ";\n" <<
62         "while(start1<end1 && start2<end2)\n" <<
63         "{\n" <<
64         "   if(!(" << comp(first2[expr<uint_>("start2")],
65                             first1[expr<uint_>("start1")]) << "))\n" <<
66         "   {\n" <<
67                 result[expr<uint_>("index")] <<
68                     " = " << first1[expr<uint_>("start1")] << ";\n" <<
69         "       index++;\n" <<
70         "       start1++;\n" <<
71         "   }\n" <<
72         "   else\n" <<
73         "   {\n" <<
74                 result[expr<uint_>("index")] <<
75                     " = " << first2[expr<uint_>("start2")] << ";\n" <<
76         "       index++;\n" <<
77         "       start2++;\n" <<
78         "   }\n" <<
79         "}\n" <<
80         "while(start1<end1)\n" <<
81         "{\n" <<
82             result[expr<uint_>("index")] <<
83                 " = " << first1[expr<uint_>("start1")] << ";\n" <<
84         "   index++;\n" <<
85         "   start1++;\n" <<
86         "}\n" <<
87         "while(start2<end2)\n" <<
88         "{\n" <<
89             result[expr<uint_>("index")] <<
90                 " = " << first2[expr<uint_>("start2")] << ";\n" <<
91         "   index++;\n" <<
92         "   start2++;\n" <<
93         "}\n";
94     }
95 
96     template<class InputIterator1, class InputIterator2,
97              class InputIterator3, class InputIterator4,
98              class OutputIterator>
set_range(InputIterator1 first1,InputIterator2 first2,InputIterator3 tile_first1,InputIterator3 tile_last1,InputIterator4 tile_first2,OutputIterator result)99     void set_range(InputIterator1 first1,
100                    InputIterator2 first2,
101                    InputIterator3 tile_first1,
102                    InputIterator3 tile_last1,
103                    InputIterator4 tile_first2,
104                    OutputIterator result)
105     {
106         typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
107         ::boost::compute::less<value_type> less_than;
108         set_range(first1, first2, tile_first1, tile_last1, tile_first2, result, less_than);
109     }
110 
exec(command_queue & queue)111     event exec(command_queue &queue)
112     {
113         if(m_count == 0) {
114             return event();
115         }
116 
117         return exec_1d(queue, 0, m_count);
118     }
119 
120 private:
121     size_t m_count;
122 };
123 
124 ///
125 /// \brief Merge algorithm with merge path
126 ///
127 /// Merges the sorted values in the range [\p first1, \p last1) with
128 /// the sorted values in the range [\p first2, last2) and stores the
129 /// result in the range beginning at \p result
130 ///
131 /// \param first1 Iterator pointing to start of first set
132 /// \param last1 Iterator pointing to end of first set
133 /// \param first2 Iterator pointing to start of second set
134 /// \param last2 Iterator pointing to end of second set
135 /// \param result Iterator pointing to start of range in which the result
136 /// will be stored
137 /// \param comp Comparator which performs less than function
138 /// \param queue Queue on which to execute
139 ///
140 template<class InputIterator1, class InputIterator2, class OutputIterator, class Compare>
141 inline OutputIterator
merge_with_merge_path(InputIterator1 first1,InputIterator1 last1,InputIterator2 first2,InputIterator2 last2,OutputIterator result,Compare comp,command_queue & queue=system::default_queue ())142 merge_with_merge_path(InputIterator1 first1,
143                       InputIterator1 last1,
144                       InputIterator2 first2,
145                       InputIterator2 last2,
146                       OutputIterator result,
147                       Compare comp,
148                       command_queue &queue = system::default_queue())
149 {
150     typedef typename
151         std::iterator_traits<OutputIterator>::difference_type result_difference_type;
152 
153     size_t tile_size = 1024;
154 
155     size_t count1 = iterator_range_size(first1, last1);
156     size_t count2 = iterator_range_size(first2, last2);
157 
158     vector<uint_> tile_a((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
159     vector<uint_> tile_b((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
160 
161     // Tile the sets
162     merge_path_kernel tiling_kernel;
163     tiling_kernel.tile_size = static_cast<unsigned int>(tile_size);
164     tiling_kernel.set_range(first1, last1, first2, last2,
165                             tile_a.begin()+1, tile_b.begin()+1, comp);
166     fill_n(tile_a.begin(), 1, uint_(0), queue);
167     fill_n(tile_b.begin(), 1, uint_(0), queue);
168     tiling_kernel.exec(queue);
169 
170     fill_n(tile_a.end()-1, 1, static_cast<uint_>(count1), queue);
171     fill_n(tile_b.end()-1, 1, static_cast<uint_>(count2), queue);
172 
173     // Merge
174     serial_merge_kernel merge_kernel;
175     merge_kernel.tile_size = static_cast<unsigned int>(tile_size);
176     merge_kernel.set_range(first1, first2, tile_a.begin(), tile_a.end(),
177                            tile_b.begin(), result, comp);
178 
179     merge_kernel.exec(queue);
180 
181     return result + static_cast<result_difference_type>(count1 + count2);
182 }
183 
184 /// \overload
185 template<class InputIterator1, class InputIterator2, class OutputIterator>
186 inline OutputIterator
merge_with_merge_path(InputIterator1 first1,InputIterator1 last1,InputIterator2 first2,InputIterator2 last2,OutputIterator result,command_queue & queue=system::default_queue ())187 merge_with_merge_path(InputIterator1 first1,
188                       InputIterator1 last1,
189                       InputIterator2 first2,
190                       InputIterator2 last2,
191                       OutputIterator result,
192                       command_queue &queue = system::default_queue())
193 {
194     typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
195     ::boost::compute::less<value_type> less_than;
196     return merge_with_merge_path(first1, last1, first2, last2, result, less_than, queue);
197 }
198 
199 } //end detail namespace
200 } //end compute namespace
201 } //end boost namespace
202 
203 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_WIH_MERGE_PATH_HPP
204