/*
//@HEADER
// ************************************************************************
//
//                        Kokkos v. 3.0
//       Copyright (2020) National Technology & Engineering
//               Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Christian R. Trott (crtrott@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef KOKKOS_SYCL_PARALLEL_SCAN_HPP
#define KOKKOS_SYCL_PARALLEL_SCAN_HPP

#include <Kokkos_Macros.hpp>
#include <memory>
#if defined(KOKKOS_ENABLE_SYCL)

namespace Kokkos {
namespace Impl {

template <class FunctorType, class... Traits>
class ParallelScanSYCLBase {
 public:
  using Policy = Kokkos::RangePolicy<Traits...>;

 protected:
  using Member       = typename Policy::member_type;
  using WorkTag      = typename Policy::work_tag;
  using WorkRange    = typename Policy::WorkRange;
  using LaunchBounds = typename Policy::launch_bounds;

  using ValueTraits = Kokkos::Impl::FunctorValueTraits<FunctorType, WorkTag>;
  using ValueInit   = Kokkos::Impl::FunctorValueInit<FunctorType, WorkTag>;
  using ValueJoin   = Kokkos::Impl::FunctorValueJoin<FunctorType, WorkTag>;
  using ValueOps    = Kokkos::Impl::FunctorValueOps<FunctorType, WorkTag>;

 public:
  using pointer_type   = typename ValueTraits::pointer_type;
  using value_type     = typename ValueTraits::value_type;
  using reference_type = typename ValueTraits::reference_type;
  using functor_type   = FunctorType;
  using size_type      = Kokkos::Experimental::SYCL::size_type;
  using index_type     = typename Policy::index_type;

 protected:
  const FunctorType m_functor;
  const Policy m_policy;
  pointer_type m_scratch_space = nullptr;

 private:
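  // Recursive Blelloch-style scan over the first `size` entries of global_mem:
  // each workgroup scans its chunk in local memory (up-sweep followed by
  // down-sweep), the per-workgroup totals are written to group_results and
  // scanned by a recursive call, and each group's offset is then joined back
  // into its elements. On return, global_mem holds the exclusive scan of its
  // original contents.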
  template <typename Functor>
  void scan_internal(sycl::queue& q, const Functor& functor,
                     pointer_type global_mem, std::size_t size) const {
    // FIXME_SYCL optimize
    constexpr size_t wgroup_size = 32;
    auto n_wgroups               = (size + wgroup_size - 1) / wgroup_size;

    // FIXME_SYCL The allocation should be handled by the execution space
    auto deleter = [&q](value_type* ptr) { sycl::free(ptr, q); };
    std::unique_ptr<value_type[], decltype(deleter)> group_results_memory(
        static_cast<pointer_type>(sycl::malloc(sizeof(value_type) * n_wgroups,
                                               q, sycl::usm::alloc::shared)),
        deleter);
    auto group_results = group_results_memory.get();

    q.submit([&](sycl::handler& cgh) {
      sycl::accessor<value_type, 1, sycl::access::mode::read_write,
                     sycl::access::target::local>
          local_mem(sycl::range<1>(wgroup_size), cgh);

      // FIXME_SYCL we get wrong results without this, not sure why
      sycl::stream out(1, 1, cgh);
      cgh.parallel_for(
          sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
          [=](sycl::nd_item<1> item) {
            const auto local_id  = item.get_local_linear_id();
            const auto global_id = item.get_global_linear_id();

            // Initialize local memory
            if (global_id < size)
              ValueOps::copy(functor, &local_mem[local_id],
                             &global_mem[global_id]);
            else
              ValueInit::init(functor, &local_mem[local_id]);
            item.barrier(sycl::access::fence_space::local_space);

            // Perform workgroup reduction
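            // (up-sweep: after the last round, local_mem[wgroup_size - 1]
            // holds the reduction over the whole workgroup)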
            for (size_t stride = 1; 2 * stride < wgroup_size + 1; stride *= 2) {
              auto idx = 2 * stride * (local_id + 1) - 1;
              if (idx < wgroup_size)
                ValueJoin::join(functor, &local_mem[idx],
                                &local_mem[idx - stride]);
              item.barrier(sycl::access::fence_space::local_space);
            }

            if (local_id == 0) {
              if (n_wgroups > 1)
                ValueOps::copy(functor,
                               &group_results[item.get_group_linear_id()],
                               &local_mem[wgroup_size - 1]);
              else
                ValueInit::init(functor,
                                &group_results[item.get_group_linear_id()]);
              ValueInit::init(functor, &local_mem[wgroup_size - 1]);
            }

            // Add results to all items
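            // (down-sweep: converts the reduction tree into an exclusive scan
            // within the workgroup)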
            for (size_t stride = wgroup_size / 2; stride > 0; stride /= 2) {
              auto idx = 2 * stride * (local_id + 1) - 1;
              if (idx < wgroup_size) {
                value_type dummy;
                ValueOps::copy(functor, &dummy, &local_mem[idx - stride]);
                ValueOps::copy(functor, &local_mem[idx - stride],
                               &local_mem[idx]);
                ValueJoin::join(functor, &local_mem[idx], &dummy);
              }
              item.barrier(sycl::access::fence_space::local_space);
            }

            // Write results to global memory
            if (global_id < size)
              ValueOps::copy(functor, &global_mem[global_id],
                             &local_mem[local_id]);
          });
    });

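    // Exclusive-scan the per-workgroup totals recursively, then fold each
    // group's offset into every element of that group.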
    if (n_wgroups > 1) scan_internal(q, functor, group_results, n_wgroups);
    m_policy.space().fence();

    q.submit([&](sycl::handler& cgh) {
      cgh.parallel_for(sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size),
                       [=](sycl::nd_item<1> item) {
                         const auto global_id = item.get_global_linear_id();
                         if (global_id < size)
                           ValueJoin::join(
                               functor, &global_mem[global_id],
                               &group_results[item.get_group_linear_id()]);
                       });
    });
    m_policy.space().fence();
  }

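  // Runs the three phases of the scan on the SYCL queue: (1) a kernel that
  // evaluates the functor with final == false to produce each work-item's
  // contribution in m_scratch_space, (2) scan_internal, which turns those
  // contributions into exclusive prefixes, and (3) a kernel that re-runs the
  // functor with final == true, passing it the scanned value.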
  template <typename Functor>
  void sycl_direct_launch(const Functor& functor) const {
    // Convenience references
    const Kokkos::Experimental::SYCL& space = m_policy.space();
    Kokkos::Experimental::Impl::SYCLInternal& instance =
        *space.impl_internal_space_instance();
    sycl::queue& q = *instance.m_queue;

    const std::size_t len = m_policy.end() - m_policy.begin();

    // Initialize global memory
    q.submit([&](sycl::handler& cgh) {
      auto global_mem = m_scratch_space;
      auto begin      = m_policy.begin();
      cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) {
        const typename Policy::index_type id =
            static_cast<typename Policy::index_type>(item.get_id()) + begin;
        value_type update{};
        ValueInit::init(functor, &update);
        if constexpr (std::is_same<WorkTag, void>::value)
          functor(id, update, false);
        else
          functor(WorkTag(), id, update, false);
        ValueOps::copy(functor, &global_mem[id], &update);
      });
    });
    space.fence();

    // Perform the actual exclusive scan
    scan_internal(q, functor, m_scratch_space, len);

    // Final pass: give each work-item its exclusive prefix, call the functor
    // with final == true, and write the updated value back to global memory
    q.submit([&](sycl::handler& cgh) {
      auto global_mem = m_scratch_space;
      cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) {
        auto global_id = item.get_id();

        value_type update = global_mem[global_id];
        if constexpr (std::is_same<WorkTag, void>::value)
          functor(global_id, update, true);
        else
          functor(WorkTag(), global_id, update, true);
        ValueOps::copy(functor, &global_mem[global_id], &update);
      });
    });
    space.fence();
  }

 public:
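  // Allocates the USM scratch buffer, wraps the functor for the indirect
  // kernel launch, runs the scan, and finally invokes post_functor (used by
  // ParallelScanWithTotal to copy out the last element).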
  template <typename PostFunctor>
  void impl_execute(const PostFunctor& post_functor) {
    if (m_policy.begin() == m_policy.end()) return;

    const auto& q = *m_policy.space().impl_internal_space_instance()->m_queue;
    const std::size_t len = m_policy.end() - m_policy.begin();

    // FIXME_SYCL The allocation should be handled by the execution space
    // consider only storing one value per block and recreate initial results in
    // the end before doing the final pass
    auto deleter = [&q](value_type* ptr) { sycl::free(ptr, q); };
    std::unique_ptr<value_type[], decltype(deleter)> result_memory(
        static_cast<pointer_type>(sycl::malloc(sizeof(value_type) * len, q,
                                               sycl::usm::alloc::shared)),
        deleter);
    m_scratch_space = result_memory.get();

    Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
        indirectKernelMem = m_policy.space()
                                .impl_internal_space_instance()
                                ->m_indirectKernelMem;

    const auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
        m_functor, indirectKernelMem);

    sycl_direct_launch(functor_wrapper.get_functor());
    post_functor();
  }

  ParallelScanSYCLBase(const FunctorType& arg_functor, const Policy& arg_policy)
      : m_functor(arg_functor), m_policy(arg_policy) {}
};
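
// The specializations below are what Kokkos::parallel_scan dispatches to when
// the RangePolicy's execution space is Kokkos::Experimental::SYCL. A minimal
// usage sketch (illustrative only; the view name and lambda are placeholders,
// not part of this file):
//
//   Kokkos::View<int*, Kokkos::Experimental::SYCLDeviceUSMSpace> values("v", n);
//   int total = 0;
//   Kokkos::parallel_scan(
//       Kokkos::RangePolicy<Kokkos::Experimental::SYCL>(0, n),
//       KOKKOS_LAMBDA(const int i, int& update, const bool final) {
//         if (final) values(i) = update;  // exclusive prefix sum
//         update += 1;
//       },
//       total);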

template <class FunctorType, class... Traits>
class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
                   Kokkos::Experimental::SYCL>
    : private ParallelScanSYCLBase<FunctorType, Traits...> {
 public:
  using Base = ParallelScanSYCLBase<FunctorType, Traits...>;

  inline void execute() {
    Base::impl_execute([]() {});
  }

  ParallelScan(const FunctorType& arg_functor,
               const typename Base::Policy& arg_policy)
      : Base(arg_functor, arg_policy) {}
};

//----------------------------------------------------------------------------

template <class FunctorType, class ReturnType, class... Traits>
class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
                            ReturnType, Kokkos::Experimental::SYCL>
    : private ParallelScanSYCLBase<FunctorType, Traits...> {
 public:
  using Base = ParallelScanSYCLBase<FunctorType, Traits...>;

  ReturnType& m_returnvalue;

  inline void execute() {
    Base::impl_execute([&]() {
      const long long nwork = Base::m_policy.end() - Base::m_policy.begin();
      if (nwork > 0) {
        const int size = Base::ValueTraits::value_size(Base::m_functor);
        DeepCopy<HostSpace, Kokkos::Experimental::SYCLDeviceUSMSpace>(
            &m_returnvalue, Base::m_scratch_space + nwork - 1, size);
      }
    });
  }

  ParallelScanWithTotal(const FunctorType& arg_functor,
                        const typename Base::Policy& arg_policy,
                        ReturnType& arg_returnvalue)
      : Base(arg_functor, arg_policy), m_returnvalue(arg_returnvalue) {}
};

}  // namespace Impl
}  // namespace Kokkos

#endif

#endif