/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_f32_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;

status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) {
    auto check_bias = [&]() -> bool {
        return !with_bias()
                || (weights_md(1)->data_type == f32 && is_bias_1xN());
    };

    bool ok = src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type && check_bias()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && set_default_formats()
            && attr_.set_default_formats(dst_md(0)) == status::success
            && gemm_based::check_gemm_compatible_formats(*this);

    if (!ok) return status::unimplemented;

    if (!has_runtime_dims_or_strides())
        params_.can_fuse_src_batch_dims_
                = matmul_helper_t(src_md(), weights_md(), dst_md())
                          .can_fuse_src_batch_dims();

    CHECK(check_and_configure_attributes());

    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_);

    return status::success;
}

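// The sum post-op can be computed by GEMM itself (via the beta parameter)
// when it is the first post-op, GEMM already applies the output scales, the
// sum has no zero-point, and its data type is either undefined or matches
// the destination data type.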
static bool should_gemm_execute_sum_po(
        const gemm_based::params_t &params, data_type_t dst_dt) noexcept {
    const auto &po = params.pp_attr_.post_ops_;
    static constexpr int sum_idx = 0;
    return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx)
            && params.gemm_applies_output_scales_
            && po.entry_[sum_idx].sum.zero_point == 0
            && utils::one_of(
                    po.entry_[sum_idx].sum.dt, dst_dt, data_type::undef);
}

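// Validates output scales and post-ops against what this implementation
// supports and fills in params_: whether GEMM applies the output scales,
// whether dst can serve as the accumulation buffer, the GEMM beta value,
// and whether a post-processing kernel is required.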
status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() {
    auto check_attr_oscale = [&]() -> bool {
        const auto &oscale = attr()->output_scales_;
        return oscale.mask_ == 0
                || (oscale.mask_ == (1 << (dst_md()->ndims - 1)));
    };

    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    // check basic attributes
    if (!check_attr_oscale()) return status::unimplemented;

    // set state
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.gemm_applies_output_scales_
            = attr()->output_scales_.mask_ == 0 && !with_bias();
    if (params_.gemm_applies_output_scales_)
        params_.pp_attr_.output_scales_.set(1.f);

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_md()->data_type);
    // set state
    params_.dst_is_acc_
            = IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}

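// The post-processing kernel must skip the sum post-op when GEMM has already
// applied it through beta.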
bool gemm_f32_matmul_t::should_skip_sum_po(data_type_t dst_dt) const noexcept {
    return should_gemm_execute_sum_po(pd()->params(), dst_dt);
}

status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    DEFINE_SCALES_BUFFER(scales);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int ndims = pd()->ndims();
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
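    // Sub-batch sizes (product of the batch dims past dim 0, resp. past dims
    // 0 and 1), used below to recover per-dimension indices for broadcasting
    // binary post-ops.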
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

    const gemm_based::params_t &params = pd()->params();
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
    bool dst_is_acc = params.dst_is_acc_;
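    // Accumulate directly into dst when it can serve as the accumulation
    // buffer; otherwise use the scratchpad booked at pd creation time.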
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // Dynamic sizes: the scratchpad could not be sized at pd creation time,
    // so fall back to a heap allocation (dnnl::impl::malloc, 64-byte aligned).
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

    const dim_t acc_ldc = dst_is_acc ? ldc : N;
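    // scale_idx_mult == 1 selects per-N (per-channel) output scales, so the
    // post-processing kernel steps through the scales array; 0 broadcasts a
    // single common scale.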
    const int scale_idx_mult
            = this->pd()->attr()->output_scales_.mask_ == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // parallelize over batch when a binary post-op uses channel broadcast
    // (except when batch == 1)
    bool is_binary_po_per_oc = false;
    bool is_binary_po_per_oc_sp = false;
    bool is_binary_po_channel_bcast = false;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallelize over batch for per_mb_spatial and per_oc
    // binary post-op broadcasts
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;
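        // Flatten the (batch, M, N) iteration space, split it evenly across
        // threads, and let each thread carve its range into the largest
        // sub-GEMMs that fit the remaining work: whole batches first, then
        // blocks of rows, then fragments of a single row.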
        parallel(nthr, [&](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

            while (i_work < t_work_end) {
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);

                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;

                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

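                // extended_sgemm follows the column-major BLAS convention, so
                // the row-major C = A * B is computed as the column-major
                // C^T = B^T * A^T by swapping the operand order.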
                status_t st_thr = extended_sgemm(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda,
                        &beta, curr_acc, &acc_ldc, nullptr, false);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;

                    // offset for the case with a channel-broadcast post-op
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, 0,
                            dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // collapse batch into M when the weights batch dimensions are
        // broadcast
        M = batch * M;

        st = extended_sgemm(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb,
                src, &lda, &beta, acc, &acc_ldc, nullptr, false);

        if (st == status::success && params.has_pp_kernel_) {
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);
            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                const size_t dst_start_row_idx = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, start, dst_logical_off,
                        dst_start_row_idx, end, (size_t)N, ldc, nullptr,
                        post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
                        *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl