1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef GPU_JIT_GEMM_GEN_GEMM_HPP 18 #define GPU_JIT_GEMM_GEN_GEMM_HPP 19 20 #include <assert.h> 21 #include <memory> 22 23 #include "common/c_types_map.hpp" 24 #include "common/gemm_utils.hpp" 25 #include "common/utils.hpp" 26 #include "gpu/compute/compute.hpp" 27 #include "gpu/compute/kernel.hpp" 28 #include "gpu/gemm/gpu_gemm.hpp" 29 #include "gpu/gpu_gemm_pd.hpp" 30 #include "gpu/jit/gemm/gen_gemm_kernel.hpp" 31 #include "gpu/primitive_conf.hpp" 32 33 namespace dnnl { 34 namespace impl { 35 namespace gpu { 36 namespace jit { 37 38 struct gen_gemm_t : public gpu_gemm_t { 39 using gpu_gemm_t::gpu_gemm_t; 40 41 struct pd_t : public gpu_gemm_pd_t { 42 using gpu_gemm_pd_t::gpu_gemm_pd_t; 43 44 DECLARE_COMMON_PD_T("jit:gemm:any", gen_gemm_t); 45 initdnnl::impl::gpu::jit::gen_gemm_t::pd_t46 status_t init(engine_t *engine) { 47 using namespace prop_kind; 48 using namespace data_type; 49 using namespace primitive_kind; 50 using smask_t = primitive_attr_t::skip_mask_t; 51 using arch_t = compute::gpu_arch_t; 52 53 assert(engine->kind() == engine_kind::gpu); 54 auto *compute_engine 55 = utils::downcast<compute::compute_engine_t *>(engine); 56 57 // LIMITATIONS: 58 // - runtime dims are not supported 59 // - bias only supported for f16 and f32 with same c_type. 60 // - postops only supported for f32. 61 bool ok = true; 62 63 auto attr_skip_mask = smask_t::oscale | smask_t::post_ops; 64 65 ok = set_default_formats(); 66 if (!ok) return status::unimplemented; 67 68 const auto d = desc(); 69 70 if (d->c_type() == s32) { 71 ok = ok && utils::one_of(d->a_type(), u8, s8) 72 && utils::one_of(d->b_type(), u8, s8) 73 && d->acc_type == d->c_type() 74 && attr()->zero_points_.defined(DNNL_ARG_SRC) 75 && attr()->zero_points_.defined(DNNL_ARG_WEIGHTS) 76 && (attr()->zero_points_.has_default_values( 77 DNNL_ARG_DST) 78 || !attr()->zero_points_.defined(DNNL_ARG_DST)); 79 80 int cmask = 0; 81 attr()->zero_points_.get( 82 DNNL_ARG_DST, nullptr, &cmask, nullptr); 83 ok &= utils::one_of(cmask, 0, 1 << 0, 1 << 1); 84 85 attr_skip_mask |= smask_t::zero_points_runtime; 86 } else { 87 ok = ok && utils::one_of(d->c_type(), f32, f16) 88 && d->a_type() == d->c_type() 89 && d->b_type() == d->c_type() 90 && d->acc_type == d->c_type(); 91 } 92 93 ok = ok && !has_blocks() && batch_dims() <= 2 94 && !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(), 95 d->k(), d->lda(), d->ldb(), d->ldc(), d->batch()) 96 && IMPLICATION(with_bias(), 97 (d->bias_type() == d->c_type()) 98 && utils::one_of( 99 bias_cmask(), 0, 1 << 0, 1 << 1)) 100 && compute_engine->mayiuse_ngen_kernels() 101 && attr()->has_default_values(attr_skip_mask) 102 && attr()->output_scales_.mask_ == 0 103 && attr()->post_ops_.len() <= 2 104 && IMPLICATION(attr()->post_ops_.len() == 1, 105 attr()->post_ops_.find(eltwise) != -1 106 || attr()->post_ops_.find(sum) != -1) 107 && IMPLICATION(attr()->post_ops_.len() == 2, 108 attr()->post_ops_.find(sum) == 0 109 && attr()->post_ops_.find(eltwise) == 1); 110 111 if (!ok) return status::unimplemented; 112 113 auto *dev_info = compute_engine->device_info(); 114 115 arch_ = dev_info->gpu_arch(); 116 117 ok &= utils::one_of(arch_, arch_t::gen9, arch_t::xe_lp); 118 119 if (!ok) return status::unimplemented; 120 121 eu_count_ = dev_info->eu_count(); 122 hw_threads_ = dev_info->hw_threads(); 123 124 attr_info_ = attr_info_t::create(attr()); 125 126 ok &= IMPLICATION(with_eltwise(), 127 jit_eltwise_injector_f32_is_supported( 128 attr_info()->eltwise_alg)); 129 130 if (!ok) return status::unimplemented; 131 132 return status::success; 133 } 134 set_default_formatsdnnl::impl::gpu::jit::gen_gemm_t::pd_t135 bool set_default_formats() { 136 return gpu_gemm_pd_t::set_default_formats(); 137 } 138 with_c_offsetdnnl::impl::gpu::jit::gen_gemm_t::pd_t139 bool with_c_offset() const { 140 return !attr()->zero_points_.has_default_values(DNNL_ARG_DST); 141 } 142 with_eltwisednnl::impl::gpu::jit::gen_gemm_t::pd_t143 bool with_eltwise() const { 144 return attr_info()->eltwise_alg != alg_kind::undef; 145 } 146 eltwise_algdnnl::impl::gpu::jit::gen_gemm_t::pd_t147 alg_kind_t eltwise_alg() const { return attr_info()->eltwise_alg; } 148 eltwise_alphadnnl::impl::gpu::jit::gen_gemm_t::pd_t149 float eltwise_alpha() const { return attr_info()->eltwise_alpha; } 150 eltwise_betadnnl::impl::gpu::jit::gen_gemm_t::pd_t151 float eltwise_beta() const { return attr_info()->eltwise_beta; } 152 eltwise_scalednnl::impl::gpu::jit::gen_gemm_t::pd_t153 float eltwise_scale() const { return attr_info()->eltwise_scale; } 154 alphadnnl::impl::gpu::jit::gen_gemm_t::pd_t155 float alpha() const { return attr()->output_scales_.scales_[0]; } 156 betadnnl::impl::gpu::jit::gen_gemm_t::pd_t157 float beta() const { 158 using namespace primitive_kind; 159 const auto &p = attr()->post_ops_; 160 return p.contain(sum, 0) ? p.entry_[0].sum.scale : 0.f; 161 } 162 with_biasdnnl::impl::gpu::jit::gen_gemm_t::pd_t163 bool with_bias() const { 164 return desc()->bias_type() != data_type::undef; 165 } 166 bias_cmaskdnnl::impl::gpu::jit::gen_gemm_t::pd_t167 int bias_cmask() const { 168 unsigned char to_cmask[4] = {0, 2, 1, 3}; 169 return with_bias() ? to_cmask[(desc()->bias_mask() >> 1) & 3] : -1; 170 } 171 batch_dimsdnnl::impl::gpu::jit::gen_gemm_t::pd_t172 int batch_dims() const { 173 return nstl::max(desc()->c_desc.ndims - 2, 0); 174 } 175 attr_infodnnl::impl::gpu::jit::gen_gemm_t::pd_t176 const attr_info_t *attr_info() const { return &attr_info_; } 177 178 size_t dyn_offset_a = 0; 179 size_t dyn_offset_b = 0; 180 size_t dyn_offset_c = 0; 181 size_t dyn_offset_co = 0; 182 int hw_threads_ = 0; 183 int eu_count_ = 0; 184 compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; 185 186 attr_info_t attr_info_ = {}; 187 }; 188 initdnnl::impl::gpu::jit::gen_gemm_t189 status_t init(engine_t *engine) override { return init_nocopy(engine); } 190 init_nocopydnnl::impl::gpu::jit::gen_gemm_t191 status_t init_nocopy(engine_t *engine) { 192 using kernel_t = gen_gemm_nocopy_kernel_t; 193 194 int unroll_m, unroll_n; 195 auto batch = pd()->desc()->batch(); 196 int batch_dims = pd()->batch_dims(); 197 bool transa = (pd()->desc()->transa() == dnnl_trans); 198 bool transb = (pd()->desc()->transb() == dnnl_trans); 199 auto a_type = pd()->desc()->a_type(); 200 auto b_type = pd()->desc()->b_type(); 201 auto c_type = pd()->desc()->c_type(); 202 auto eltwise_alg = pd()->eltwise_alg(); 203 auto eltwise_alpha = pd()->eltwise_alpha(); 204 auto eltwise_beta = pd()->eltwise_beta(); 205 auto eltwise_scale = pd()->eltwise_scale(); 206 207 kernel_t::choose_unrolls(pd()->arch_, pd()->hw_threads_, transa, transb, 208 a_type, b_type, c_type, pd()->desc()->m(), pd()->desc()->n(), 209 pd()->desc()->k(), batch, unroll_m, unroll_n); 210 211 kernel_t kernel; 212 213 auto status = kernel.init(pd()->arch_, batch_dims, transa, transb, 214 pd()->with_c_offset(), pd()->with_bias(), eltwise_alg, 215 eltwise_alpha, eltwise_beta, eltwise_scale, a_type, b_type, 216 c_type, unroll_m, unroll_n); 217 if (status != status::success) return status; 218 219 create_kernel(engine, &nocopy_kernel_, kernel); 220 221 nocopy_info_ = kernel.driver_info(); 222 223 return status::success; 224 } 225 226 status_t execute(const gemm_exec_ctx_t &ctx) const override; 227 228 private: 229 status_t launch_nocopy(const gemm_exec_ctx_t &ctx, 230 compute::compute_stream_t *s, const memory_storage_t &a, 231 const memory_storage_t &b, const memory_storage_t &c, 232 const memory_storage_t &co, int64_t offset_a, int64_t offset_b, 233 int64_t offset_c, int32_t offset_co, int32_t lda, int32_t ldb, 234 int32_t ldc, int32_t m, int32_t n, int32_t k, float alpha, 235 float beta, int16_t ao, int16_t bo, int32_t cmask, 236 bool last_k_block) const; 237 238 compute::kernel_t nocopy_kernel_; 239 CommonDriverInfo nocopy_info_; 240 pddnnl::impl::gpu::jit::gen_gemm_t241 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 242 }; 243 244 } // namespace jit 245 } // namespace gpu 246 } // namespace impl 247 } // namespace dnnl 248 #endif 249 250 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 251