/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/gemm/gen_gemm.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

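// Dispatch one no-copy GEMM kernel launch: assemble the kernel argument list
// and ND-range from the problem parameters, then enqueue the generated kernel.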
status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx,
        compute::compute_stream_t *compute_stream, const memory_storage_t &a,
        const memory_storage_t &b, const memory_storage_t &c,
        const memory_storage_t &co, int64_t offset_a, int64_t offset_b,
        int64_t offset_c, int32_t offset_co, int32_t lda, int32_t ldb,
        int32_t ldc, int32_t m, int32_t n, int32_t k, int32_t k0, float alpha,
        float beta, int16_t ao, int16_t bo, int32_t cmask, bool last_k_block,
        bool swapab, bool disable_hilbert) const {

    uint32_t flags = 0;
    bool k_parallel = (nocopy_info_.kParallel || nocopy_info_.kParallelLocal);

    auto stride_a0 = int32_t(pd()->desc()->stride_a(0));
    auto stride_b0 = int32_t(pd()->desc()->stride_b(0));
    auto stride_c0 = int32_t(pd()->desc()->stride_c(0));

    auto stride_a1 = int32_t(pd()->desc()->stride_a(1));
    auto stride_b1 = int32_t(pd()->desc()->stride_b(1));
    auto stride_c1 = int32_t(pd()->desc()->stride_c(1));

    if (swapab) {
        std::swap(stride_a0, stride_b0);
        std::swap(stride_a1, stride_b1);
    }

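    // Kernel flags: mark non-final k blocks and record which dimensions the
    // C offset (zero point or bias) varies along.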
    if (!last_k_block) flags |= FlagNonfinalKBlock;
    if (cmask & 1) flags |= FlagCOColumn;
    if (cmask & 2) flags |= FlagCORow;

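    // Assemble the kernel argument list in the order expected by the
    // generated kernel; optional arguments depend on the problem attributes.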
    compute::kernel_arg_list_t arg_list;
    int argn = 0;

    arg_list.set(argn++, a);
    arg_list.set(argn++, b);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);
    if (pd()->with_ab_zero_points()) {
        uint32_t abo = uint16_t(-ao) | (uint16_t(-bo) << 16);
        arg_list.set(argn++, abo);
    }
    if (pd()->with_c_zero_points() || pd()->with_bias()) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
    }
    arg_list.set(argn++, flags);
    if (k_parallel) arg_list.set(argn++, k0);

    if (pd()->batch_dims() >= 1) {
        arg_list.set(argn++, stride_a0);
        arg_list.set(argn++, stride_b0);
        arg_list.set(argn++, stride_c0);
    }

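    // With a second batch dimension, also pass its size and a precomputed
    // fixed-point reciprocal of that size, so the kernel can split the flat
    // batch index without an integer divide.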
    if (pd()->batch_dims() >= 2) {
        auto batchSize1 = uint32_t(pd()->desc()->c_desc.dims[1]);
        uint32_t recipBatchSize1 = (uint32_t)utils::div_up(
                uint64_t(0x100000000) << math::ilog2q(batchSize1), batchSize1);
        arg_list.set(argn++, stride_a1);
        arg_list.set(argn++, stride_b1);
        arg_list.set(argn++, stride_c1);
        arg_list.set(argn++, batchSize1);
        arg_list.set(argn++, recipBatchSize1);
    }

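    // Global work size: the m and n dimensions are tiled by the kernel unroll;
    // the third dimension indexes k slices (k-parallel) or batches.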
    size_t gws[3] = {0, 0, 1};

    gws[0] = utils::div_up(m, nocopy_info_.unroll[0]);
    gws[1] = utils::div_up(n, nocopy_info_.unroll[1]);
    gws[2] = k_parallel ? nstl::max(1, utils::div_up(k, k0))
                        : pd()->desc()->batch();

    size_t lws[3] = {size_t(nocopy_info_.wg[0]), size_t(nocopy_info_.wg[1]),
            size_t(nocopy_info_.wg[2])};

    if (nocopy_info_.isNMK()) {
        std::swap(lws[0], lws[1]);
        std::swap(gws[0], gws[1]);
    }

    if (nocopy_info_.fusedEUs && (lws[0] > 1))
        gws[0] = utils::rnd_up(gws[0], 2);

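    // Find the last dimension in which both the global and local sizes
    // exceed 1.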
    int last_non_1 = 2;
    for (; last_non_1 >= 0 && (gws[last_non_1] == 1 || lws[last_non_1] == 1);
            last_non_1--)
        ;

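    // Round each global size up to a multiple of the local size, or shrink
    // the local size to match when a fixed work-group shape is not required.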
    for (int d = 0; d < 2; d++) {
        if (nocopy_info_.fixedWG || (gws[d] > lws[d]))
            gws[d] = utils::rnd_up(gws[d], lws[d]);
        else {
            // Workaround to avoid local ID reordering until reqd_walk_group_order implemented in UMD.
            if (pd()->arch_ >= compute::gpu_arch_t::xe_hp && d < last_non_1)
                gws[d] = utils::rnd_up_pow2(gws[d]);
            lws[d] = gws[d];
        }
    }

    lws[1] *= nocopy_info_.wgExpand;
    gws[1] *= nocopy_info_.wgExpand;

    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, disable_hilbert,
            nocopy_info_, pd()->dev_info_);

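    // Dimension 0 is counted in subgroups up to this point; convert it to
    // work-items before building the ND-range.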
    lws[0] *= nocopy_info_.subgroupSize;
    gws[0] *= nocopy_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);
    return parallel_for(ctx, nd_range, nocopy_kernel_, arg_list);
}

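// Top-level GEMM execution: derive problem parameters from the primitive
// descriptor, choose blocking, and issue one or more no-copy kernel launches.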
status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();

    auto *compute_stream
            = utils::downcast<compute::compute_stream_t *>(ctx.stream());

    const bool swapab = pd()->swap_ab();

    const auto m = swapab ? pd()->desc()->n() : pd()->desc()->m();
    const auto n = swapab ? pd()->desc()->m() : pd()->desc()->n();
    auto k = pd()->desc()->k();

    const bool transa = swapab ? (pd()->desc()->transb() == dnnl_notrans)
                               : (pd()->desc()->transa() == dnnl_trans);
    const bool transb = swapab ? false : (pd()->desc()->transb() == dnnl_trans);

    const auto lda = swapab ? pd()->desc()->ldb() : pd()->desc()->lda();
    const auto ldb = swapab ? pd()->desc()->lda() : pd()->desc()->ldb();
    auto ldc = pd()->desc()->ldc();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

    bool k_parallel = nocopy_info_.kParallel || nocopy_info_.kParallelLocal;

    auto &a = swapab ? GEMM_CTX_ARG_STORAGE(a) : GEMM_CTX_ARG_STORAGE(b);
    auto &b = swapab ? GEMM_CTX_ARG_STORAGE(b) : GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto *co = &c_zp;

    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    int16_t ao = 0, bo = 0;
    int cmask = 0;

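    // The C offset buffer is the C zero point for integer (s32) destinations;
    // otherwise, when bias is enabled, it points at the bias tensor.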
    if (c_type == data_type::s32) {
        off_co0 = co->offset() / types::data_type_size(c_type)
                + pd()->dyn_offset_co;
    } else if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(c_type);
        co = &bias;
        cmask = pd()->bias_cmask();
    }

    if (pd()->with_ab_zero_points()) {
        const int *ao_i32 = nullptr;
        const int *bo_i32 = nullptr;
        pd()->attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
        pd()->attr()->zero_points_.get(
                DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
        ao = *ao_i32;
        bo = *bo_i32;
    }
    if (pd()->with_c_zero_points())
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask, nullptr);

    status_t status;

    auto block_m = nocopy_info_.blocking[0];
    auto block_n = nocopy_info_.blocking[1];
    auto block_k = nocopy_info_.blocking[2];

    bool disable_hilbert = (k <= 64) && nocopy_info_.isHilbert();
    if (disable_hilbert) {
        block_m = nocopy_info_.blockingAlt[0];
        block_n = nocopy_info_.blockingAlt[1];
    }

    if (!utils::one_of(pd()->desc()->c_type(), data_type::f32, data_type::f16))
        block_k = k;

    block_m = utils::rnd_up(
            block_m, nocopy_info_.wg[0] * nocopy_info_.unroll[0]);
    block_n = utils::rnd_up(
            block_n, nocopy_info_.wg[1] * nocopy_info_.unroll[1]);
    block_k = utils::rnd_up(block_k, nocopy_info_.unroll[2]);

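    // For k-parallel kernels, each launch covers the whole k range in slices
    // of size k0. Beta is applied in a separate pass (a launch with k = 0)
    // when needed, and subsequent launches use beta = 1.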
    int32_t k0 = 1;
    if (k_parallel) {
        k0 = block_k;
        block_k = k;

        if (beta != 1.0f && (k > k0 * nocopy_info_.wg[2])) {
            status = launch_nocopy(ctx, compute_stream, a, b, c, *co, off_a0,
                    off_b0, off_c0, int32_t(off_co0), lda, ldb, ldc, m, n, 0, 1,
                    1.0f, beta, 0, 0, 0, false, swapab, true);
            if (status) return status;
            beta = 1.0f;
        }
    }

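    // Partition the problem into blocks of k, m, and n, launching one kernel
    // per block. Beta is applied only on the first k block; later k blocks
    // accumulate with beta = 1.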
    for (int64_t Bk = 0; Bk < k; Bk += block_k) {
        int64_t size_k = k - Bk;
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_src
                    = off_a0 + (!transa ? (Bm + Bk * lda) : (Bk + Bm * lda));

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_src = off_b0
                        + (!transb ? (Bk + Bn * ldb) : (Bn + Bk * ldb));

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                if (cmask & 1) off_co += Bn;
                if (cmask & 2) off_co += Bm;

                float eff_beta = (Bk == 0) ? beta : 1.0f;
                status = launch_nocopy(ctx, compute_stream, a, b, c, *co,
                        off_a_src, off_b_src, off_c, off_co, lda, ldb, ldc,
                        size_m, size_n, size_k, k0, alpha, eff_beta, ao, bo,
                        cmask, last_k_block, swapab, disable_hilbert);

                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s