1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_COMPUTE_KERNEL_ARG_LIST_HPP
18 #define GPU_COMPUTE_KERNEL_ARG_LIST_HPP
19 
#include <cassert>
#include <cstddef>
#include <new>
#include <type_traits>

#include "common/bfloat16.hpp"
#include "common/float16.hpp"
#include "common/memory_storage.hpp"
#include "common/nstl.hpp"

#include "gpu/zero_pad_struct.h"
30 
31 namespace dnnl {
32 namespace impl {
33 namespace gpu {
34 namespace compute {
35 
/// Kind of value held by a kernel_arg_t.
enum class kernel_arg_kind_t {
    undef = 0, ///< Not set yet (default state of kernel_arg_t).
    global = 1, ///< Address of a memory_storage_t object.
    local = 2, ///< Size only; value() is nullptr.
    scalar = 3, ///< Copy of a by-value scalar (see scalar_type_t).
    svm = 4, ///< Raw SVM pointer.
};
43 
/// Type tag of a scalar (by-value) kernel argument. The C++ type mapped to
/// each tag is given by the scalar_type_traits specializations.
enum class scalar_type_t {
    undef = 0,
    _char = 1,
    _bfloat16 = 2,
    _float = 3,
    _half = 4,
    _int = 5,
    _long = 6,
    _short = 7,
    _uchar = 8,
    _uint = 9,
    _ulong = 10,
    _ushort = 11,
    _zero_pad_mask_t = 12,
};
59 
60 template <typename T>
61 struct scalar_type_traits {};
62 
63 template <>
64 struct scalar_type_traits<float16_t> {
65     static const auto type = scalar_type_t::_half;
66 };
67 template <>
68 struct scalar_type_traits<bfloat16_t> {
69     static const auto type = scalar_type_t::_bfloat16;
70 };
71 template <>
72 struct scalar_type_traits<float> {
73     static const auto type = scalar_type_t::_float;
74 };
75 
76 template <>
77 struct scalar_type_traits<uint8_t> {
78     static const auto type = scalar_type_t::_uchar;
79 };
80 template <>
81 struct scalar_type_traits<uint16_t> {
82     static const auto type = scalar_type_t::_ushort;
83 };
84 template <>
85 struct scalar_type_traits<uint32_t> {
86     static const auto type = scalar_type_t::_uint;
87 };
88 template <>
89 struct scalar_type_traits<uint64_t> {
90     static const auto type = scalar_type_t::_ulong;
91 };
92 
93 template <>
94 struct scalar_type_traits<int8_t> {
95     static const auto type = scalar_type_t::_char;
96 };
97 template <>
98 struct scalar_type_traits<int16_t> {
99     static const auto type = scalar_type_t::_short;
100 };
101 template <>
102 struct scalar_type_traits<int32_t> {
103     static const auto type = scalar_type_t::_int;
104 };
105 template <>
106 struct scalar_type_traits<int64_t> {
107     static const auto type = scalar_type_t::_long;
108 };
109 template <>
110 struct scalar_type_traits<zero_pad_mask_t> {
111     static const auto type = scalar_type_t::_zero_pad_mask_t;
112 };
113 
114 class kernel_arg_t {
115 public:
kind() const116     kernel_arg_kind_t kind() const { return kind_; }
scalar_type() const117     scalar_type_t scalar_type() const { return scalar_type_; }
size() const118     size_t size() const { return size_; }
119 
is_global() const120     bool is_global() const { return kind() == kernel_arg_kind_t::global; }
is_local() const121     bool is_local() const { return kind() == kernel_arg_kind_t::local; }
is_svm_pointer() const122     bool is_svm_pointer() const { return kind_ == kernel_arg_kind_t::svm; }
123 
set_value(const memory_storage_t & storage)124     kernel_arg_t &set_value(const memory_storage_t &storage) {
125         kind_ = kernel_arg_kind_t::global;
126         size_ = 0;
127         value_ = static_cast<const void *>(&storage);
128         return *this;
129     }
130 
131     template <typename T>
set_value(const T & value,void * & data_pool)132     kernel_arg_t &set_value(const T &value, void *&data_pool) {
133         assert(size_ <= sizeof(T));
134         if (value_ == nullptr) {
135             assert(data_pool != nullptr);
136             size_ = sizeof(T);
137             data_pool = utils::align_ptr(data_pool, alignof(T));
138             value_ = data_pool;
139             data_pool = static_cast<char *>(data_pool) + size_;
140         }
141         kind_ = kernel_arg_kind_t::scalar;
142         scalar_type_ = scalar_type_traits<T>::type;
143         new (const_cast<void *>(value_)) T(value);
144         return *this;
145     }
146 
set_value(size_t size,std::nullptr_t)147     kernel_arg_t &set_value(size_t size, std::nullptr_t) {
148         kind_ = kernel_arg_kind_t::local;
149         size_ = size;
150         value_ = nullptr;
151         return *this;
152     }
153 
set_value(void * svm_ptr,kernel_arg_kind_t kind)154     void set_value(void *svm_ptr, kernel_arg_kind_t kind) {
155         assert(kind == kernel_arg_kind_t::svm);
156         kind_ = kernel_arg_kind_t::svm;
157         size_ = 0;
158         value_ = svm_ptr;
159     }
160 
value() const161     const void *value() const {
162         assert(kind() != kernel_arg_kind_t::undef);
163         return value_;
164     }
165 
166     template <typename T>
as() const167     T as() const {
168         assert(kind() == kernel_arg_kind_t::scalar);
169         assert(scalar_type() == scalar_type_traits<T>::type);
170         return *(const T *)value();
171     }
172 
173     static kernel_arg_t cast(scalar_type_t other_type,
174             const kernel_arg_t &other, void *&cast_storage);
175 
176 private:
177     kernel_arg_kind_t kind_ = kernel_arg_kind_t::undef;
178     scalar_type_t scalar_type_ = scalar_type_t::undef;
179     size_t size_ = 0;
180     const void *value_ = nullptr;
181 };
182 
183 class kernel_arg_list_t {
184 public:
kernel_arg_list_t()185     kernel_arg_list_t() { nargs_ = 0; }
set(int index,const memory_storage_t & storage)186     void set(int index, const memory_storage_t &storage) {
187         assert(index < max_args);
188         nargs_ = nstl::max(nargs_, index + 1);
189         args_[index].set_value(storage);
190     }
191 
set(int index,void * value,kernel_arg_kind_t kind)192     void set(int index, void *value, kernel_arg_kind_t kind) {
193         assert(index < max_args);
194         nargs_ = nstl::max(nargs_, index + 1);
195         args_[index].set_value(value, kind);
196     }
197 
198     template <class T>
set(int index,const T & value)199     void set(int index, const T &value) {
200         assert(index < max_args);
201         nargs_ = nstl::max(nargs_, index + 1);
202         args_[index].set_value(value, unused_storage);
203 
204         assert(unused_storage
205                 <= reinterpret_cast<char *>(&scalar_storage_) + storage_size);
206     }
207 
set(int index,size_t size,std::nullptr_t)208     void set(int index, size_t size, std::nullptr_t) {
209         assert(index < max_args);
210         nargs_ = nstl::max(nargs_, index + 1);
211         args_[index].set_value(size, nullptr);
212     }
213 
nargs() const214     int nargs() const { return nargs_; }
215 
get(int index) const216     const kernel_arg_t &get(int index) const {
217         assert(index < nargs());
218         return args_[index];
219     }
220 
get_memory_storage(int index) const221     const memory_storage_t &get_memory_storage(int index) const {
222         assert(args_[index].kind() == kernel_arg_kind_t::global);
223         return *static_cast<const memory_storage_t *>(args_[index].value());
224     }
225 
226 private:
227     static constexpr int max_args = 96;
228     static constexpr int storage_size = 512;
229     static constexpr int storage_alginment = 8;
230 
231     int nargs_ = 0;
232     kernel_arg_t args_[max_args];
233     typename std::aligned_storage<storage_size, storage_alginment>::type
234             scalar_storage_;
235     void *unused_storage = &scalar_storage_;
236 
237     kernel_arg_list_t(const kernel_arg_list_t &) = delete;
238     kernel_arg_list_t(kernel_arg_list_t &&) = delete;
239     kernel_arg_list_t &operator=(const kernel_arg_list_t &) = delete;
240     kernel_arg_list_t &operator=(kernel_arg_list_t &&) = delete;
241 };
242 
243 } // namespace compute
244 } // namespace gpu
245 } // namespace impl
246 } // namespace dnnl
247 
248 #endif // GPU_COMPUTE_KERNEL_ARG_LIST_HPP
249