/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/jit/gemm/gen_gemm.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_traits.hpp"
#include "common/float16.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "gpu/jit/gemm/gemm_walk_orders.hpp"
#include "gpu/jit/gemm/gen_gemm_kernel_common.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

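// Launch a single invocation of the no-copy GEMM kernel on one tile of the
// problem (one m/n/k block, possibly one batch slice).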
status_t gen_gemm_t::launch_nocopy(const gemm_exec_ctx_t &ctx,
        compute::compute_stream_t *compute_stream, const memory_storage_t &a,
        const memory_storage_t &b, const memory_storage_t &c,
        const memory_storage_t &co, int64_t offset_a, int64_t offset_b,
        int64_t offset_c, int32_t offset_co, int32_t lda, int32_t ldb,
        int32_t ldc, int32_t m, int32_t n, int32_t k, int32_t k0, float alpha,
        float beta, int16_t ao, int16_t bo, int32_t cmask, bool last_k_block,
        bool swapab, bool disable_hilbert) const {

    uint32_t flags = 0;
    bool k_parallel = (nocopy_info_.kParallel || nocopy_info_.kParallelLocal);

    auto stride_a0 = int32_t(pd()->desc()->stride_a(0));
    auto stride_b0 = int32_t(pd()->desc()->stride_b(0));
    auto stride_c0 = int32_t(pd()->desc()->stride_c(0));

    auto stride_a1 = int32_t(pd()->desc()->stride_a(1));
    auto stride_b1 = int32_t(pd()->desc()->stride_b(1));
    auto stride_c1 = int32_t(pd()->desc()->stride_c(1));

    if (swapab) {
        std::swap(stride_a0, stride_b0);
        std::swap(stride_a1, stride_b1);
    }

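    // Dispatch flags: whether more k blocks follow this launch, and how the
    // C offset (bias or C zero point) broadcasts across rows/columns.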
    if (!last_k_block) flags |= FlagNonfinalKBlock;
    if (cmask & 1) flags |= FlagCOColumn;
    if (cmask & 2) flags |= FlagCORow;

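    // Kernel arguments are positional; the order here must match the order
    // expected by the generated kernel.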
    compute::kernel_arg_list_t arg_list;
    int argn = 0;

    arg_list.set(argn++, a);
    arg_list.set(argn++, b);
    arg_list.set(argn++, c);
    arg_list.set(argn++, offset_a);
    arg_list.set(argn++, offset_b);
    arg_list.set(argn++, offset_c);
    arg_list.set(argn++, lda);
    arg_list.set(argn++, ldb);
    arg_list.set(argn++, ldc);
    arg_list.set(argn++, m);
    arg_list.set(argn++, n);
    arg_list.set(argn++, k);
    arg_list.set(argn++, alpha);
    arg_list.set(argn++, beta);
    if (pd()->with_ab_zero_points()) {
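        // Pack the negated A and B zero points into a single 32-bit
        // argument: A in the low word, B in the high word.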
        uint32_t abo = uint16_t(-ao) | (uint16_t(-bo) << 16);
        arg_list.set(argn++, abo);
    }
    if (pd()->with_c_zero_points() || pd()->with_bias()) {
        arg_list.set(argn++, co);
        arg_list.set(argn++, offset_co);
    }
    arg_list.set(argn++, flags);
    if (k_parallel) arg_list.set(argn++, k0);

    if (pd()->batch_dims() >= 1) {
        arg_list.set(argn++, stride_a0);
        arg_list.set(argn++, stride_b0);
        arg_list.set(argn++, stride_c0);
    }

    if (pd()->batch_dims() >= 2) {
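        // Precompute a fixed-point reciprocal of the second batch dimension,
        // recip = ceil(2^(32 + ilog2(size)) / size), presumably so the kernel
        // can replace the batch-index division with a multiply and shift.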
        auto batchSize1 = uint32_t(pd()->desc()->c_desc.dims[1]);
        uint32_t recipBatchSize1 = (uint32_t)utils::div_up(
                uint64_t(0x100000000) << math::ilog2q(batchSize1), batchSize1);
        arg_list.set(argn++, stride_a1);
        arg_list.set(argn++, stride_b1);
        arg_list.set(argn++, stride_c1);
        arg_list.set(argn++, batchSize1);
        arg_list.set(argn++, recipBatchSize1);
    }

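    // Global sizes: one subgroup per unroll tile of C in dimensions 0 and 1;
    // dimension 2 covers the k chunks (when k-parallel) or the batch.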
    size_t gws[3] = {0, 0, 1};

    gws[0] = utils::div_up(m, nocopy_info_.unroll[0]);
    gws[1] = utils::div_up(n, nocopy_info_.unroll[1]);
    gws[2] = k_parallel ? nstl::max(1, utils::div_up(k, k0))
                        : pd()->desc()->batch();

    size_t lws[3] = {size_t(nocopy_info_.wg[0]), size_t(nocopy_info_.wg[1]),
            size_t(nocopy_info_.wg[2])};

    if (nocopy_info_.isNMK()) {
        std::swap(lws[0], lws[1]);
        std::swap(gws[0], gws[1]);
    }

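    // For fused-EU kernels, round the thread count in dimension 0 up to an
    // even number.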
    if (nocopy_info_.fusedEUs && (lws[0] > 1))
        gws[0] = utils::rnd_up(gws[0], 2);

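    // Find the last dimension in which both the global and local sizes are
    // nontrivial; earlier dimensions may be padded to powers of two below.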
    int last_non_1 = 2;
    for (; last_non_1 >= 0 && (gws[last_non_1] == 1 || lws[last_non_1] == 1);
            last_non_1--)
        ;

    for (int d = 0; d < 2; d++) {
        if (nocopy_info_.fixedWG || (gws[d] > lws[d]))
            gws[d] = utils::rnd_up(gws[d], lws[d]);
        else {
            // Workaround to avoid local ID reordering until
            // reqd_walk_group_order is implemented in UMD.
            if (pd()->arch_ >= compute::gpu_arch_t::xe_hp && d < last_non_1)
                gws[d] = utils::rnd_up_pow2(gws[d]);
            lws[d] = gws[d];
        }
    }

    lws[1] *= nocopy_info_.wgExpand;
    gws[1] *= nocopy_info_.wgExpand;

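    // Append the workgroup walk-order arguments (linear or Hilbert ordering).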
    gemm_linear_order_args(arg_list, argn, lws, gws, m, n, disable_hilbert,
            nocopy_info_, pd()->dev_info_);

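    // Dimension 0 has been measured in subgroups up to this point; convert
    // it to work-items.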
    lws[0] *= nocopy_info_.subgroupSize;
    gws[0] *= nocopy_info_.subgroupSize;

    auto nd_range = compute::nd_range_t(gws, lws);
    return parallel_for(ctx, nd_range, nocopy_kernel_, arg_list);
}

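// Execute a GEMM: split the problem into blocks as needed and dispatch the
// no-copy kernel over each block.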
status_t gen_gemm_t::execute(const gemm_exec_ctx_t &ctx) const {
    auto a_type = pd()->desc()->a_type();
    auto b_type = pd()->desc()->b_type();
    auto c_type = pd()->desc()->c_type();

    auto *compute_stream
            = utils::downcast<compute::compute_stream_t *>(ctx.stream());

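    // When A/B swapping is enabled, compute C^T = B^T A^T instead, exchanging
    // the roles of m and n, of A and B, and of their transpose flags and
    // leading dimensions.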
    const bool swapab = pd()->swap_ab();

    const auto m = swapab ? pd()->desc()->n() : pd()->desc()->m();
    const auto n = swapab ? pd()->desc()->m() : pd()->desc()->n();
    auto k = pd()->desc()->k();

    const bool transa = swapab ? (pd()->desc()->transb() == dnnl_notrans)
                               : (pd()->desc()->transa() == dnnl_trans);
    const bool transb = swapab ? false : (pd()->desc()->transb() == dnnl_trans);

    const auto lda = swapab ? pd()->desc()->ldb() : pd()->desc()->lda();
    const auto ldb = swapab ? pd()->desc()->lda() : pd()->desc()->ldb();
    auto ldc = pd()->desc()->ldc();

    auto alpha = pd()->alpha();
    auto beta = pd()->beta();

    bool k_parallel = nocopy_info_.kParallel || nocopy_info_.kParallelLocal;

    auto &a = swapab ? GEMM_CTX_ARG_STORAGE(a) : GEMM_CTX_ARG_STORAGE(b);
    auto &b = swapab ? GEMM_CTX_ARG_STORAGE(b) : GEMM_CTX_ARG_STORAGE(a);
    auto &c = GEMM_CTX_ARG_STORAGE(c);
    auto &c_zp = GEMM_CTX_ARG_STORAGE(c_zero_point);
    auto &bias = GEMM_CTX_ARG_STORAGE(bias);
    auto *co = &c_zp;

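    // Convert the buffers' byte offsets to element offsets, adding the
    // dynamic offsets recorded in the primitive descriptor.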
    size_t off_a0
            = a.offset() / types::data_type_size(a_type) + pd()->dyn_offset_a;
    size_t off_b0
            = b.offset() / types::data_type_size(b_type) + pd()->dyn_offset_b;
    size_t off_c0
            = c.offset() / types::data_type_size(c_type) + pd()->dyn_offset_c;
    size_t off_co0 = 0;

    int16_t ao = 0, bo = 0;
    int cmask = 0;

    if (c_type == data_type::s32) {
        off_co0 = co->offset() / types::data_type_size(c_type)
                + pd()->dyn_offset_co;
    } else if (pd()->with_bias()) {
        off_co0 = bias.offset() / types::data_type_size(c_type);
        co = &bias;
        cmask = pd()->bias_cmask();
    }

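    // Retrieve A/B zero points (negated and packed in launch_nocopy) and the
    // C zero-point broadcast mask from the primitive attributes.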
    if (pd()->with_ab_zero_points()) {
        const int *ao_i32 = nullptr;
        const int *bo_i32 = nullptr;
        pd()->attr()->zero_points_.get(DNNL_ARG_SRC, nullptr, nullptr, &ao_i32);
        pd()->attr()->zero_points_.get(
                DNNL_ARG_WEIGHTS, nullptr, nullptr, &bo_i32);
        ao = *ao_i32;
        bo = *bo_i32;
    }
    if (pd()->with_c_zero_points())
        pd()->attr()->zero_points_.get(DNNL_ARG_DST, nullptr, &cmask, nullptr);

    status_t status;

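    // Start from the kernel's preferred m/n/k blocking and round each block
    // up to a whole number of workgroup tiles.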
    auto block_m = nocopy_info_.blocking[0];
    auto block_n = nocopy_info_.blocking[1];
    auto block_k = nocopy_info_.blocking[2];

    bool disable_hilbert = (k <= 64) && nocopy_info_.isHilbert();
    if (disable_hilbert) {
        block_m = nocopy_info_.blockingAlt[0];
        block_n = nocopy_info_.blockingAlt[1];
    }

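    // Only split the k loop for f32/f16 C; for other destination types,
    // process all of k in a single chunk.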
    if (!utils::one_of(pd()->desc()->c_type(), data_type::f32, data_type::f16))
        block_k = k;

    block_m = utils::rnd_up(
            block_m, nocopy_info_.wg[0] * nocopy_info_.unroll[0]);
    block_n = utils::rnd_up(
            block_n, nocopy_info_.wg[1] * nocopy_info_.unroll[1]);
    block_k = utils::rnd_up(block_k, nocopy_info_.unroll[2]);

    int32_t k0 = 1;
    if (k_parallel) {
        k0 = block_k;
        block_k = k;

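        // Each k chunk accumulates into C, so when beta != 1 and k is split
        // across multiple chunks, first scale C by beta with an empty
        // (k = 0) launch, then accumulate all chunks with beta = 1.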
        if (beta != 1.0f && (k > k0 * nocopy_info_.wg[2])) {
            status = launch_nocopy(ctx, compute_stream, a, b, c, *co, off_a0,
                    off_b0, off_c0, int32_t(off_co0), lda, ldb, ldc, m, n, 0, 1,
                    1.0f, beta, 0, 0, 0, false, swapab, true);
            if (status) return status;
            beta = 1.0f;
        }
    }

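    // Walk the problem in block_k x block_m x block_n tiles, launching one
    // kernel per tile. Only the first k chunk applies the original beta;
    // later chunks accumulate with beta = 1.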
    for (int64_t Bk = 0; Bk < k; Bk += block_k) {
        int64_t size_k = k - Bk;
        bool last_k_block = (size_k <= block_k);
        if (!last_k_block) size_k = block_k;

        for (int64_t Bm = 0; Bm < m; Bm += block_m) {
            int64_t size_m = m - Bm;
            if (size_m > block_m) size_m = block_m;

            auto off_a_src
                    = off_a0 + (!transa ? (Bm + Bk * lda) : (Bk + Bm * lda));

            for (int64_t Bn = 0; Bn < n; Bn += block_n) {
                int64_t size_n = n - Bn;
                if (size_n > block_n) size_n = block_n;

                auto off_b_src = off_b0
                        + (!transb ? (Bk + Bn * ldb) : (Bn + Bk * ldb));

                auto off_c = off_c0 + Bm + Bn * ldc;
                auto off_co = int32_t(off_co0);
                if (cmask & 1) off_co += Bn;
                if (cmask & 2) off_co += Bm;

                float eff_beta = (Bk == 0) ? beta : 1.0f;
                status = launch_nocopy(ctx, compute_stream, a, b, c, *co,
                        off_a_src, off_b_src, off_c, off_co, lda, ldb, ldc,
                        size_m, size_n, size_k, k0, alpha, eff_beta, ao, bo,
                        cmask, last_k_block, swapab, disable_hilbert);

                if (status) return status;
            }
        }
    }

    return status::success;
}

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s