1 /******************************************************************************* 2 * Copyright 2019-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef GPU_COMPUTE_KERNEL_ARG_LIST_HPP 18 #define GPU_COMPUTE_KERNEL_ARG_LIST_HPP 19 20 #include <cassert> 21 #include <cstddef> 22 #include <type_traits> 23 24 #include "common/bfloat16.hpp" 25 #include "common/float16.hpp" 26 #include "common/memory_storage.hpp" 27 #include "common/nstl.hpp" 28 29 #include "gpu/zero_pad_struct.h" 30 31 namespace dnnl { 32 namespace impl { 33 namespace gpu { 34 namespace compute { 35 36 enum class kernel_arg_kind_t { 37 undef, 38 global, 39 local, 40 scalar, 41 svm, 42 }; 43 44 enum class scalar_type_t { 45 undef, 46 _char, 47 _bfloat16, 48 _float, 49 _half, 50 _int, 51 _long, 52 _short, 53 _uchar, 54 _uint, 55 _ulong, 56 _ushort, 57 _zero_pad_mask_t, 58 }; 59 60 template <typename T> 61 struct scalar_type_traits {}; 62 63 template <> 64 struct scalar_type_traits<float16_t> { 65 static const auto type = scalar_type_t::_half; 66 }; 67 template <> 68 struct scalar_type_traits<bfloat16_t> { 69 static const auto type = scalar_type_t::_bfloat16; 70 }; 71 template <> 72 struct scalar_type_traits<float> { 73 static const auto type = scalar_type_t::_float; 74 }; 75 76 template <> 77 struct scalar_type_traits<uint8_t> { 78 static const auto type = scalar_type_t::_uchar; 79 }; 80 template <> 81 struct scalar_type_traits<uint16_t> { 82 static const auto type = scalar_type_t::_ushort; 83 }; 84 template <> 85 struct scalar_type_traits<uint32_t> { 86 static const auto type = scalar_type_t::_uint; 87 }; 88 template <> 89 struct scalar_type_traits<uint64_t> { 90 static const auto type = scalar_type_t::_ulong; 91 }; 92 93 template <> 94 struct scalar_type_traits<int8_t> { 95 static const auto type = scalar_type_t::_char; 96 }; 97 template <> 98 struct scalar_type_traits<int16_t> { 99 static const auto type = scalar_type_t::_short; 100 }; 101 template <> 102 struct scalar_type_traits<int32_t> { 103 static const auto type = scalar_type_t::_int; 104 }; 105 template <> 106 struct scalar_type_traits<int64_t> { 107 static const auto type = scalar_type_t::_long; 108 }; 109 template <> 110 struct scalar_type_traits<zero_pad_mask_t> { 111 static const auto type = scalar_type_t::_zero_pad_mask_t; 112 }; 113 114 class kernel_arg_t { 115 public: kind() const116 kernel_arg_kind_t kind() const { return kind_; } scalar_type() const117 scalar_type_t scalar_type() const { return scalar_type_; } size() const118 size_t size() const { return size_; } 119 is_global() const120 bool is_global() const { return kind() == kernel_arg_kind_t::global; } is_local() const121 bool is_local() const { return kind() == kernel_arg_kind_t::local; } is_svm_pointer() const122 bool is_svm_pointer() const { return kind_ == kernel_arg_kind_t::svm; } 123 set_value(const memory_storage_t & storage)124 kernel_arg_t &set_value(const memory_storage_t &storage) { 125 kind_ = kernel_arg_kind_t::global; 126 size_ = 0; 127 value_ = static_cast<const void *>(&storage); 128 return *this; 129 } 130 131 template <typename T> set_value(const T & value,void * & data_pool)132 kernel_arg_t &set_value(const T &value, void *&data_pool) { 133 assert(size_ <= sizeof(T)); 134 if (value_ == nullptr) { 135 assert(data_pool != nullptr); 136 size_ = sizeof(T); 137 data_pool = utils::align_ptr(data_pool, alignof(T)); 138 value_ = data_pool; 139 data_pool = static_cast<char *>(data_pool) + size_; 140 } 141 kind_ = kernel_arg_kind_t::scalar; 142 scalar_type_ = scalar_type_traits<T>::type; 143 new (const_cast<void *>(value_)) T(value); 144 return *this; 145 } 146 set_value(size_t size,std::nullptr_t)147 kernel_arg_t &set_value(size_t size, std::nullptr_t) { 148 kind_ = kernel_arg_kind_t::local; 149 size_ = size; 150 value_ = nullptr; 151 return *this; 152 } 153 set_value(void * svm_ptr,kernel_arg_kind_t kind)154 void set_value(void *svm_ptr, kernel_arg_kind_t kind) { 155 assert(kind == kernel_arg_kind_t::svm); 156 kind_ = kernel_arg_kind_t::svm; 157 size_ = 0; 158 value_ = svm_ptr; 159 } 160 value() const161 const void *value() const { 162 assert(kind() != kernel_arg_kind_t::undef); 163 return value_; 164 } 165 166 template <typename T> as() const167 T as() const { 168 assert(kind() == kernel_arg_kind_t::scalar); 169 assert(scalar_type() == scalar_type_traits<T>::type); 170 return *(const T *)value(); 171 } 172 173 static kernel_arg_t cast(scalar_type_t other_type, 174 const kernel_arg_t &other, void *&cast_storage); 175 176 private: 177 kernel_arg_kind_t kind_ = kernel_arg_kind_t::undef; 178 scalar_type_t scalar_type_ = scalar_type_t::undef; 179 size_t size_ = 0; 180 const void *value_ = nullptr; 181 }; 182 183 class kernel_arg_list_t { 184 public: kernel_arg_list_t()185 kernel_arg_list_t() { nargs_ = 0; } set(int index,const memory_storage_t & storage)186 void set(int index, const memory_storage_t &storage) { 187 assert(index < max_args); 188 nargs_ = nstl::max(nargs_, index + 1); 189 args_[index].set_value(storage); 190 } 191 set(int index,void * value,kernel_arg_kind_t kind)192 void set(int index, void *value, kernel_arg_kind_t kind) { 193 assert(index < max_args); 194 nargs_ = nstl::max(nargs_, index + 1); 195 args_[index].set_value(value, kind); 196 } 197 198 template <class T> set(int index,const T & value)199 void set(int index, const T &value) { 200 assert(index < max_args); 201 nargs_ = nstl::max(nargs_, index + 1); 202 args_[index].set_value(value, unused_storage); 203 204 assert(unused_storage 205 <= reinterpret_cast<char *>(&scalar_storage_) + storage_size); 206 } 207 set(int index,size_t size,std::nullptr_t)208 void set(int index, size_t size, std::nullptr_t) { 209 assert(index < max_args); 210 nargs_ = nstl::max(nargs_, index + 1); 211 args_[index].set_value(size, nullptr); 212 } 213 nargs() const214 int nargs() const { return nargs_; } 215 get(int index) const216 const kernel_arg_t &get(int index) const { 217 assert(index < nargs()); 218 return args_[index]; 219 } 220 get_memory_storage(int index) const221 const memory_storage_t &get_memory_storage(int index) const { 222 assert(args_[index].kind() == kernel_arg_kind_t::global); 223 return *static_cast<const memory_storage_t *>(args_[index].value()); 224 } 225 226 private: 227 static constexpr int max_args = 96; 228 static constexpr int storage_size = 512; 229 static constexpr int storage_alginment = 8; 230 231 int nargs_ = 0; 232 kernel_arg_t args_[max_args]; 233 typename std::aligned_storage<storage_size, storage_alginment>::type 234 scalar_storage_; 235 void *unused_storage = &scalar_storage_; 236 237 kernel_arg_list_t(const kernel_arg_list_t &) = delete; 238 kernel_arg_list_t(kernel_arg_list_t &&) = delete; 239 kernel_arg_list_t &operator=(const kernel_arg_list_t &) = delete; 240 kernel_arg_list_t &operator=(kernel_arg_list_t &&) = delete; 241 }; 242 243 } // namespace compute 244 } // namespace gpu 245 } // namespace impl 246 } // namespace dnnl 247 248 #endif // GPU_COMPUTE_KERNEL_ARG_LIST_HPP 249