/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_GEMM_GEN_GEMM_HPP
#define GPU_JIT_GEMM_GEN_GEMM_HPP

#include <assert.h>
#include <memory>

#include "common/c_types_map.hpp"
#include "common/gemm_utils.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/compute/kernel.hpp"
#include "gpu/gemm/gpu_gemm.hpp"
#include "gpu/gpu_gemm_pd.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

struct gen_gemm_t : public gpu_gemm_t {
    using gpu_gemm_t::gpu_gemm_t;

    struct pd_t : public gpu_gemm_pd_t {
        using gpu_gemm_pd_t::gpu_gemm_pd_t;

        DECLARE_COMMON_PD_T("jit:gemm:any", gen_gemm_t);

        status_t init(engine_t *engine) {
            using namespace prop_kind;
            using namespace data_type;
            using namespace primitive_kind;
            using smask_t = primitive_attr_t::skip_mask_t;
            using arch_t = compute::gpu_arch_t;

            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            // LIMITATIONS:
            // - runtime dims are not supported
            // - bias is only supported for f16 and f32, with the same type as C
            // - post-ops are only supported for f32
            bool ok = true;

            auto attr_skip_mask = smask_t::oscale | smask_t::post_ops;

            ok = set_default_formats();
            if (!ok) return status::unimplemented;

            const auto d = desc();

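            // Int8 GEMM (s32 accumulation): A/B must be u8/s8, and the A/B
            // zero points must be known at primitive creation time.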
            if (d->c_type() == s32) {
                ok = ok && utils::one_of(d->a_type(), u8, s8)
                        && utils::one_of(d->b_type(), u8, s8)
                        && d->acc_type == d->c_type()
                        && attr()->zero_points_.defined(DNNL_ARG_SRC)
                        && attr()->zero_points_.defined(DNNL_ARG_WEIGHTS)
                        && (attr()->zero_points_.has_default_values(
                                    DNNL_ARG_DST)
                                || !attr()->zero_points_.defined(DNNL_ARG_DST));

                int cmask = 0;
                attr()->zero_points_.get(
                        DNNL_ARG_DST, nullptr, &cmask, nullptr);
                ok &= utils::one_of(cmask, 0, 1 << 0, 1 << 1);

                attr_skip_mask |= smask_t::zero_points_runtime;
            } else {
                ok = ok && utils::one_of(d->c_type(), f32, f16)
                        && d->a_type() == d->c_type()
                        && d->b_type() == d->c_type()
                        && d->acc_type == d->c_type();
            }

            ok = ok && !has_blocks() && batch_dims() <= 2
                    && !utils::one_of(DNNL_RUNTIME_DIM_VAL, d->m(), d->n(),
                            d->k(), d->lda(), d->ldb(), d->ldc(), d->batch())
                    && IMPLICATION(with_bias(),
                            (d->bias_type() == d->c_type())
                                    && utils::one_of(
                                            bias_cmask(), 0, 1 << 0, 1 << 1))
                    && compute_engine->mayiuse_ngen_kernels()
                    && attr()->has_default_values(attr_skip_mask)
                    && attr()->output_scales_.mask_ == 0
                    && attr()->post_ops_.len() <= 2
                    && IMPLICATION(attr()->post_ops_.len() == 1,
                            attr()->post_ops_.find(eltwise) != -1
                                    || attr()->post_ops_.find(sum) != -1)
                    && IMPLICATION(attr()->post_ops_.len() == 2,
                            attr()->post_ops_.find(sum) == 0
                                    && attr()->post_ops_.find(eltwise) == 1);

            if (!ok) return status::unimplemented;

            auto *dev_info = compute_engine->device_info();

            arch_ = dev_info->gpu_arch();

            ok &= utils::one_of(arch_, arch_t::gen9, arch_t::xe_lp);

            if (!ok) return status::unimplemented;

            eu_count_ = dev_info->eu_count();
            hw_threads_ = dev_info->hw_threads();

            attr_info_ = attr_info_t::create(attr());

            ok &= IMPLICATION(with_eltwise(),
                    jit_eltwise_injector_f32_is_supported(
                            attr_info()->eltwise_alg));

            if (!ok) return status::unimplemented;

            return status::success;
        }

        bool set_default_formats() {
            return gpu_gemm_pd_t::set_default_formats();
        }

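        // True when a zero point (offset) is applied to the C matrix.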
        bool with_c_offset() const {
            return !attr()->zero_points_.has_default_values(DNNL_ARG_DST);
        }

        bool with_eltwise() const {
            return attr_info()->eltwise_alg != alg_kind::undef;
        }

        alg_kind_t eltwise_alg() const { return attr_info()->eltwise_alg; }

        float eltwise_alpha() const { return attr_info()->eltwise_alpha; }

        float eltwise_beta() const { return attr_info()->eltwise_beta; }

        float eltwise_scale() const { return attr_info()->eltwise_scale; }

        float alpha() const { return attr()->output_scales_.scales_[0]; }

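        // beta comes from the scale of a leading sum post-op, or 0 if none.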
        float beta() const {
            using namespace primitive_kind;
            const auto &p = attr()->post_ops_;
            return p.contain(sum, 0) ? p.entry_[0].sum.scale : 0.f;
        }

        bool with_bias() const {
            return desc()->bias_type() != data_type::undef;
        }

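        // Map the descriptor's bias mask onto the kernel's cmask encoding;
        // returns -1 when there is no bias.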
        int bias_cmask() const {
            unsigned char to_cmask[4] = {0, 2, 1, 3};
            return with_bias() ? to_cmask[(desc()->bias_mask() >> 1) & 3] : -1;
        }

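        // Number of batch dimensions of C (0 for a plain 2D GEMM).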
        int batch_dims() const {
            return nstl::max(desc()->c_desc.ndims - 2, 0);
        }

        const attr_info_t *attr_info() const { return &attr_info_; }

        size_t dyn_offset_a = 0;
        size_t dyn_offset_b = 0;
        size_t dyn_offset_c = 0;
        size_t dyn_offset_co = 0;
        int hw_threads_ = 0;
        int eu_count_ = 0;
        compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown;

        attr_info_t attr_info_ = {};
    };

    status_t init(engine_t *engine) override { return init_nocopy(engine); }

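    // Choose unrolls for this problem and generate the nocopy GEMM kernel.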
    status_t init_nocopy(engine_t *engine) {
        using kernel_t = gen_gemm_nocopy_kernel_t;

        int unroll_m, unroll_n;
        auto batch = pd()->desc()->batch();
        int batch_dims = pd()->batch_dims();
        bool transa = (pd()->desc()->transa() == dnnl_trans);
        bool transb = (pd()->desc()->transb() == dnnl_trans);
        auto a_type = pd()->desc()->a_type();
        auto b_type = pd()->desc()->b_type();
        auto c_type = pd()->desc()->c_type();
        auto eltwise_alg = pd()->eltwise_alg();
        auto eltwise_alpha = pd()->eltwise_alpha();
        auto eltwise_beta = pd()->eltwise_beta();
        auto eltwise_scale = pd()->eltwise_scale();

        kernel_t::choose_unrolls(pd()->arch_, pd()->hw_threads_, transa, transb,
                a_type, b_type, c_type, pd()->desc()->m(), pd()->desc()->n(),
                pd()->desc()->k(), batch, unroll_m, unroll_n);

        kernel_t kernel;

        auto status = kernel.init(pd()->arch_, batch_dims, transa, transb,
                pd()->with_c_offset(), pd()->with_bias(), eltwise_alg,
                eltwise_alpha, eltwise_beta, eltwise_scale, a_type, b_type,
                c_type, unroll_m, unroll_n);
        if (status != status::success) return status;

        CHECK(create_kernel(engine, &nocopy_kernel_, kernel));

        nocopy_info_ = kernel.driver_info();

        return status::success;
    }

    status_t execute(const gemm_exec_ctx_t &ctx) const override;

private:
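    // Bind kernel arguments and enqueue the nocopy kernel on the given stream.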
    status_t launch_nocopy(const gemm_exec_ctx_t &ctx,
            compute::compute_stream_t *s, const memory_storage_t &a,
            const memory_storage_t &b, const memory_storage_t &c,
            const memory_storage_t &co, int64_t offset_a, int64_t offset_b,
            int64_t offset_c, int32_t offset_co, int32_t lda, int32_t ldb,
            int32_t ldc, int32_t m, int32_t n, int32_t k, float alpha,
            float beta, int16_t ao, int16_t bo, int32_t cmask,
            bool last_k_block) const;

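    // Generated nocopy kernel and the driver info used to set up its dispatch.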
    compute::kernel_t nocopy_kernel_;
    CommonDriverInfo nocopy_info_;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl
#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s