/*
//@HEADER
// ************************************************************************
//
//                        Kokkos v. 3.0
//       Copyright (2020) National Technology & Engineering
//               Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Christian R.
Trott (crtrott@sandia.gov) 40 // 41 // ************************************************************************ 42 //@HEADER 43 */ 44 45 #ifndef KOKKO_SYCL_PARALLEL_SCAN_HPP 46 #define KOKKO_SYCL_PARALLEL_SCAN_HPP 47 48 #include <Kokkos_Macros.hpp> 49 #include <memory> 50 #if defined(KOKKOS_ENABLE_SYCL) 51 52 namespace Kokkos { 53 namespace Impl { 54 55 template <class FunctorType, class... Traits> 56 class ParallelScanSYCLBase { 57 public: 58 using Policy = Kokkos::RangePolicy<Traits...>; 59 60 protected: 61 using Member = typename Policy::member_type; 62 using WorkTag = typename Policy::work_tag; 63 using WorkRange = typename Policy::WorkRange; 64 using LaunchBounds = typename Policy::launch_bounds; 65 66 using ValueTraits = Kokkos::Impl::FunctorValueTraits<FunctorType, WorkTag>; 67 using ValueInit = Kokkos::Impl::FunctorValueInit<FunctorType, WorkTag>; 68 using ValueJoin = Kokkos::Impl::FunctorValueJoin<FunctorType, WorkTag>; 69 using ValueOps = Kokkos::Impl::FunctorValueOps<FunctorType, WorkTag>; 70 71 public: 72 using pointer_type = typename ValueTraits::pointer_type; 73 using value_type = typename ValueTraits::value_type; 74 using reference_type = typename ValueTraits::reference_type; 75 using functor_type = FunctorType; 76 using size_type = Kokkos::Experimental::SYCL::size_type; 77 using index_type = typename Policy::index_type; 78 79 protected: 80 const FunctorType m_functor; 81 const Policy m_policy; 82 pointer_type m_scratch_space = nullptr; 83 84 private: 85 template <typename Functor> scan_internal(sycl::queue & q,const Functor & functor,pointer_type global_mem,std::size_t size) const86 void scan_internal(sycl::queue& q, const Functor& functor, 87 pointer_type global_mem, std::size_t size) const { 88 // FIXME_SYCL optimize 89 constexpr size_t wgroup_size = 32; 90 auto n_wgroups = (size + wgroup_size - 1) / wgroup_size; 91 92 // FIXME_SYCL The allocation should be handled by the execution space 93 auto deleter = [&q](value_type* ptr) { sycl::free(ptr, 
q); }; 94 std::unique_ptr<value_type[], decltype(deleter)> group_results_memory( 95 static_cast<pointer_type>(sycl::malloc(sizeof(value_type) * n_wgroups, 96 q, sycl::usm::alloc::shared)), 97 deleter); 98 auto group_results = group_results_memory.get(); 99 100 q.submit([&](sycl::handler& cgh) { 101 sycl::accessor<value_type, 1, sycl::access::mode::read_write, 102 sycl::access::target::local> 103 local_mem(sycl::range<1>(wgroup_size), cgh); 104 105 // FIXME_SYCL we get wrong results without this, not sure why 106 sycl::stream out(1, 1, cgh); 107 cgh.parallel_for( 108 sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size), 109 [=](sycl::nd_item<1> item) { 110 const auto local_id = item.get_local_linear_id(); 111 const auto global_id = item.get_global_linear_id(); 112 113 // Initialize local memory 114 if (global_id < size) 115 ValueOps::copy(functor, &local_mem[local_id], 116 &global_mem[global_id]); 117 else 118 ValueInit::init(functor, &local_mem[local_id]); 119 item.barrier(sycl::access::fence_space::local_space); 120 121 // Perform workgroup reduction 122 for (size_t stride = 1; 2 * stride < wgroup_size + 1; stride *= 2) { 123 auto idx = 2 * stride * (local_id + 1) - 1; 124 if (idx < wgroup_size) 125 ValueJoin::join(functor, &local_mem[idx], 126 &local_mem[idx - stride]); 127 item.barrier(sycl::access::fence_space::local_space); 128 } 129 130 if (local_id == 0) { 131 if (n_wgroups > 1) 132 ValueOps::copy(functor, 133 &group_results[item.get_group_linear_id()], 134 &local_mem[wgroup_size - 1]); 135 else 136 ValueInit::init(functor, 137 &group_results[item.get_group_linear_id()]); 138 ValueInit::init(functor, &local_mem[wgroup_size - 1]); 139 } 140 141 // Add results to all items 142 for (size_t stride = wgroup_size / 2; stride > 0; stride /= 2) { 143 auto idx = 2 * stride * (local_id + 1) - 1; 144 if (idx < wgroup_size) { 145 value_type dummy; 146 ValueOps::copy(functor, &dummy, &local_mem[idx - stride]); 147 ValueOps::copy(functor, &local_mem[idx - stride], 148 
&local_mem[idx]); 149 ValueJoin::join(functor, &local_mem[idx], &dummy); 150 } 151 item.barrier(sycl::access::fence_space::local_space); 152 } 153 154 // Write results to global memory 155 if (global_id < size) 156 ValueOps::copy(functor, &global_mem[global_id], 157 &local_mem[local_id]); 158 }); 159 }); 160 161 if (n_wgroups > 1) scan_internal(q, functor, group_results, n_wgroups); 162 m_policy.space().fence(); 163 164 q.submit([&](sycl::handler& cgh) { 165 cgh.parallel_for(sycl::nd_range<1>(n_wgroups * wgroup_size, wgroup_size), 166 [=](sycl::nd_item<1> item) { 167 const auto global_id = item.get_global_linear_id(); 168 if (global_id < size) 169 ValueJoin::join( 170 functor, &global_mem[global_id], 171 &group_results[item.get_group_linear_id()]); 172 }); 173 }); 174 m_policy.space().fence(); 175 } 176 177 template <typename Functor> sycl_direct_launch(const Functor & functor) const178 void sycl_direct_launch(const Functor& functor) const { 179 // Convenience references 180 const Kokkos::Experimental::SYCL& space = m_policy.space(); 181 Kokkos::Experimental::Impl::SYCLInternal& instance = 182 *space.impl_internal_space_instance(); 183 sycl::queue& q = *instance.m_queue; 184 185 const std::size_t len = m_policy.end() - m_policy.begin(); 186 187 // Initialize global memory 188 q.submit([&](sycl::handler& cgh) { 189 auto global_mem = m_scratch_space; 190 auto begin = m_policy.begin(); 191 cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) { 192 const typename Policy::index_type id = 193 static_cast<typename Policy::index_type>(item.get_id()) + begin; 194 value_type update{}; 195 ValueInit::init(functor, &update); 196 if constexpr (std::is_same<WorkTag, void>::value) 197 functor(id, update, false); 198 else 199 functor(WorkTag(), id, update, false); 200 ValueOps::copy(functor, &global_mem[id], &update); 201 }); 202 }); 203 space.fence(); 204 205 // Perform the actual exlcusive scan 206 scan_internal(q, functor, m_scratch_space, len); 207 208 // Write 
results to global memory 209 q.submit([&](sycl::handler& cgh) { 210 auto global_mem = m_scratch_space; 211 cgh.parallel_for(sycl::range<1>(len), [=](sycl::item<1> item) { 212 auto global_id = item.get_id(); 213 214 value_type update = global_mem[global_id]; 215 if constexpr (std::is_same<WorkTag, void>::value) 216 functor(global_id, update, true); 217 else 218 functor(WorkTag(), global_id, update, true); 219 ValueOps::copy(functor, &global_mem[global_id], &update); 220 }); 221 }); 222 space.fence(); 223 } 224 225 public: 226 template <typename PostFunctor> impl_execute(const PostFunctor & post_functor)227 void impl_execute(const PostFunctor& post_functor) { 228 if (m_policy.begin() == m_policy.end()) return; 229 230 const auto& q = *m_policy.space().impl_internal_space_instance()->m_queue; 231 const std::size_t len = m_policy.end() - m_policy.begin(); 232 233 // FIXME_SYCL The allocation should be handled by the execution space 234 // consider only storing one value per block and recreate initial results in 235 // the end before doing the final pass 236 auto deleter = [&q](value_type* ptr) { sycl::free(ptr, q); }; 237 std::unique_ptr<value_type[], decltype(deleter)> result_memory( 238 static_cast<pointer_type>(sycl::malloc(sizeof(value_type) * len, q, 239 sycl::usm::alloc::shared)), 240 deleter); 241 m_scratch_space = result_memory.get(); 242 243 Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem& 244 indirectKernelMem = m_policy.space() 245 .impl_internal_space_instance() 246 ->m_indirectKernelMem; 247 248 const auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper( 249 m_functor, indirectKernelMem); 250 251 sycl_direct_launch(functor_wrapper.get_functor()); 252 post_functor(); 253 } 254 ParallelScanSYCLBase(const FunctorType & arg_functor,const Policy & arg_policy)255 ParallelScanSYCLBase(const FunctorType& arg_functor, const Policy& arg_policy) 256 : m_functor(arg_functor), m_policy(arg_policy) {} 257 }; 258 259 template <class 
FunctorType, class... Traits> 260 class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, 261 Kokkos::Experimental::SYCL> 262 : private ParallelScanSYCLBase<FunctorType, Traits...> { 263 public: 264 using Base = ParallelScanSYCLBase<FunctorType, Traits...>; 265 execute()266 inline void execute() { 267 Base::impl_execute([]() {}); 268 } 269 ParallelScan(const FunctorType & arg_functor,const typename Base::Policy & arg_policy)270 ParallelScan(const FunctorType& arg_functor, 271 const typename Base::Policy& arg_policy) 272 : Base(arg_functor, arg_policy) {} 273 }; 274 275 //---------------------------------------------------------------------------- 276 277 template <class FunctorType, class ReturnType, class... Traits> 278 class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>, 279 ReturnType, Kokkos::Experimental::SYCL> 280 : private ParallelScanSYCLBase<FunctorType, Traits...> { 281 public: 282 using Base = ParallelScanSYCLBase<FunctorType, Traits...>; 283 284 ReturnType& m_returnvalue; 285 execute()286 inline void execute() { 287 Base::impl_execute([&]() { 288 const long long nwork = Base::m_policy.end() - Base::m_policy.begin(); 289 if (nwork > 0) { 290 const int size = Base::ValueTraits::value_size(Base::m_functor); 291 DeepCopy<HostSpace, Kokkos::Experimental::SYCLDeviceUSMSpace>( 292 &m_returnvalue, Base::m_scratch_space + nwork - 1, size); 293 } 294 }); 295 } 296 ParallelScanWithTotal(const FunctorType & arg_functor,const typename Base::Policy & arg_policy,ReturnType & arg_returnvalue)297 ParallelScanWithTotal(const FunctorType& arg_functor, 298 const typename Base::Policy& arg_policy, 299 ReturnType& arg_returnvalue) 300 : Base(arg_functor, arg_policy), m_returnvalue(arg_returnvalue) {} 301 }; 302 303 } // namespace Impl 304 } // namespace Kokkos 305 306 #endif 307 308 #endif 309