/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <atomic>

#include <assert.h>
#include <float.h>
#include <math.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/gemm/gemm.hpp"

#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_f32_matmul.hpp"
#include "cpu/matmul/matmul_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace matmul {

using namespace data_type;
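// Checks that this implementation can handle the requested matmul: f32
// src/weights/dst (with an optional 1xN f32 bias), attributes restricted to
// runtime output scales, post-ops, and sum data type, and GEMM-compatible
// memory formats. On success, records whether the src batch dimensions can
// be fused into M and books the accumulator scratchpad.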
status_t gemm_f32_matmul_t::pd_t::init(engine_t *engine) {
    auto check_bias = [&]() -> bool {
        return !with_bias()
                || (weights_md(1)->data_type == f32 && is_bias_1xN());
    };

    bool ok = src_md()->data_type == src_type
            && weights_md()->data_type == weights_type
            && desc()->accum_data_type == acc_type
            && dst_md()->data_type == dst_type && check_bias()
            && attr()->has_default_values(
                    primitive_attr_t::skip_mask_t::oscale_runtime
                            | primitive_attr_t::skip_mask_t::post_ops
                            | primitive_attr_t::skip_mask_t::sum_dt,
                    dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && set_default_formats()
            && attr_.set_default_formats(dst_md(0)) == status::success
            && gemm_based::check_gemm_compatible_formats(*this);

    if (!ok) return status::unimplemented;

    if (!has_runtime_dims_or_strides())
        params_.can_fuse_src_batch_dims_
                = matmul_helper_t(src_md(), weights_md(), dst_md())
                          .can_fuse_src_batch_dims();

    CHECK(check_and_configure_attributes());

    nthr_ = dnnl_get_max_threads();
    gemm_based::book_acc_scratchpad(*this, params_, sizeof(acc_data_t), nthr_);

    return status::success;
}
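// Returns true when the sum post-op can be folded into the GEMM beta
// parameter instead of being applied by the post-processing kernel. This is
// only safe when sum is the first post-op, GEMM already applies the output
// scales, the sum has no zero-point, and the sum data type matches the
// destination (or is unspecified).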
static bool should_gemm_execute_sum_po(
        const gemm_based::params_t &params, data_type_t dst_dt) noexcept {
    const auto &po = params.pp_attr_.post_ops_;
    static constexpr int sum_idx = 0;
    return po.len() > 0 && po.contain(primitive_kind::sum, sum_idx)
            && params.gemm_applies_output_scales_
            && po.entry_[sum_idx].sum.zero_point == 0
            && utils::one_of(
                    po.entry_[sum_idx].sum.dt, dst_dt, data_type::undef);
}
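// Validates output scales and post-ops, then derives the GEMM-related state:
// whether GEMM applies the output scales itself (scalar mask, no bias),
// whether dst can be used directly as the accumulation buffer, the beta
// value implementing a folded sum post-op, and whether a post-processing
// kernel is needed at all.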
status_t gemm_f32_matmul_t::pd_t::check_and_configure_attributes() {
    auto check_attr_oscale = [&]() -> bool {
        const auto &oscale = attr()->output_scales_;
        return oscale.mask_ == 0
                || (oscale.mask_ == (1 << (dst_md()->ndims - 1)));
    };

    auto check_attr_post_ops = [&]() -> bool {
        using namespace primitive_kind;
        const auto &post_ops = attr()->post_ops_;
        static const bcast_set_t enabled_bcast_strategy {
                broadcasting_strategy_t::scalar,
                broadcasting_strategy_t::per_oc,
                broadcasting_strategy_t::per_oc_spatial,
                broadcasting_strategy_t::per_mb_spatial,
                broadcasting_strategy_t::per_mb_w,
                broadcasting_strategy_t::no_broadcast};
        const bool is_binary_po_per_oc
                = binary_injector_utils::bcast_strategy_present(
                        binary_injector_utils::extract_bcast_strategies(
                                post_ops.entry_, dst_md()),
                        broadcasting_strategy_t::per_oc);
        return cpu::inner_product_utils::post_ops_ok(
                       post_ops, dst_md(), enabled_bcast_strategy)
                && IMPLICATION(is_binary_po_per_oc,
                        gemm_based::check_gemm_binary_per_oc_compatible_formats(
                                *this));
    };

    // check basic attributes
    if (!check_attr_oscale()) return status::unimplemented;

    // set state
    CHECK(params_.pp_attr_.copy_from(*attr()));
    params_.gemm_applies_output_scales_
            = attr()->output_scales_.mask_ == 0 && !with_bias();
    if (params_.gemm_applies_output_scales_)
        params_.pp_attr_.output_scales_.set(1.f);

    // check post-ops
    if (!check_attr_post_ops()) return status::unimplemented;
    const bool sum_po_via_gemm_beta
            = should_gemm_execute_sum_po(params_, dst_md()->data_type);
    // set state
    params_.dst_is_acc_
            = IMPLICATION(attr()->post_ops_.find(primitive_kind::sum) != -1,
                    sum_po_via_gemm_beta);

    if (sum_po_via_gemm_beta) {
        // set state
        const auto &po = params_.pp_attr_.post_ops_;
        static constexpr int sum_idx = 0;
        params_.gemm_beta_ = po.entry_[sum_idx].sum.scale;
    }

    // set state
    params_.has_pp_kernel_ = !params_.dst_is_acc_ || with_bias()
            || !params_.pp_attr_.has_default_values();

    return status::success;
}
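// True when the sum post-op must be skipped by the post-processing kernel
// because it has already been folded into the GEMM beta parameter.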
bool gemm_f32_matmul_t::should_skip_sum_po(data_type_t dst_dt) const noexcept {
    return should_gemm_execute_sum_po(pd()->params(), dst_dt);
}
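// Reference execution. Computes dst = alpha * src * weights + beta * dst via
// f32 GEMM, then applies bias, output scales, and post-ops through the
// post-processing kernel when one is required. Two paths exist: if the batch
// cannot be collapsed into M (or a binary post-op broadcast needs per-batch
// handling), the (batch x M x N) output space is split across threads and
// each thread issues its own GEMM calls; otherwise batch is folded into M
// and a single GEMM call covers the whole problem.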
status_t gemm_f32_matmul_t::execute_ref(const exec_ctx_t &ctx) const {
    using namespace binary_injector_utils;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const weights_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto &po = this->pd()->attr()->post_ops_;
    const auto post_ops_binary_rhs_arg_vec = prepare_binary_args(po, ctx);

    DEFINE_SCALES_BUFFER(scales);

    const auto src_d = ctx.memory_mdw(DNNL_ARG_SRC, pd()->src_md());
    const auto weights_d = ctx.memory_mdw(DNNL_ARG_WEIGHTS, pd()->weights_md());
    const auto dst_d = ctx.memory_mdw(DNNL_ARG_DST, pd()->dst_md());

    matmul_helper_t helper(src_d, weights_d, dst_d);
    const int ndims = pd()->ndims();
    const int batch_ndims = ndims - 2;
    dim_t M = helper.M();
    const dim_t N = helper.N();
    const dim_t K = helper.K();
    const dim_t batch = helper.batch();
    const dim_t batch_without_dim0
            = helper.ndims() > 3 ? batch / dst_d.dims()[0] : 0;
    const dim_t batch_without_dim01
            = helper.ndims() > 4 ? batch_without_dim0 / dst_d.dims()[1] : 1;
    const char transA = helper.transA();
    const char transB = helper.transB();
    const dim_t lda = helper.lda();
    const dim_t ldb = helper.ldb();
    const dim_t ldc = helper.ldc();
    const int nthr = pd()->nthr_;

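    // alpha carries the scalar output scale when GEMM applies the scales
    // itself (see gemm_based::params_t::get_gemm_alpha), and is 1.f
    // otherwise; beta is nonzero only when a sum post-op was folded into
    // GEMM in check_and_configure_attributes().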
    const gemm_based::params_t &params = pd()->params();
    const float alpha = params.get_gemm_alpha(scales);
    const float beta = params.gemm_beta_;
    const bool can_fuse_src_batch_dims = pd()->has_runtime_dims_or_strides()
            ? helper.can_fuse_src_batch_dims()
            : params.can_fuse_src_batch_dims_;
    const dim_t acc_stride = gemm_based::get_scratchpad_size(
            batch, M, N, can_fuse_src_batch_dims, nthr);
    bool dst_is_acc = params.dst_is_acc_;
    acc_data_t *acc = dst_is_acc
            ? (acc_data_t *)dst
            : ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    memory_tracking::names::key_matmul_dst_in_acc_dt);
    // case: dynamic sizes
    bool need_free_acc = false;
    if (acc == nullptr) {
        acc = (acc_data_t *)malloc(sizeof(acc_data_t) * acc_stride
                        * ((can_fuse_src_batch_dims || batch == 1) ? 1 : nthr),
                64);
        if (acc == nullptr) return status::out_of_memory;
        need_free_acc = true;
    }

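    // When dst doubles as the accumulator, GEMM writes with the dst leading
    // dimension; the dense scratchpad uses N instead. scale_idx_mult selects
    // between a single common scale (0) and per-N-column scales (1) when
    // indexing the output scales array below.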
    const dim_t acc_ldc = dst_is_acc ? ldc : N;
    const int scale_idx_mult
            = this->pd()->attr()->output_scales_.mask_ == (1 << (ndims - 1));

    std::atomic<status_t> st(status::success);
    // use parallelization over batch when a binary post-op broadcasts over
    // channels (except when batch == 1)
    bool is_binary_po_per_oc = false;
    bool is_binary_po_per_oc_sp = false;
    bool is_binary_po_channel_bcast = false;
    std::tie(is_binary_po_per_oc, is_binary_po_per_oc_sp,
            is_binary_po_channel_bcast)
            = bcast_strategies_present_tup(po.entry_, pd()->dst_md(),
                    broadcasting_strategy_t::per_oc,
                    broadcasting_strategy_t::per_oc_spatial,
                    broadcasting_strategy_t::per_mb_spatial);
    // if batched, parallelize over batch for per_mb_sp and per_oc binary
    // post-op broadcasts
    const bool can_use_po_with_fused_batch = !is_binary_po_channel_bcast
            && IMPLICATION(
                    is_binary_po_per_oc || is_binary_po_per_oc_sp, ndims == 2);
    const bool parallel_over_batch = batch > 1 && !can_fuse_src_batch_dims;
    if (IMPLICATION(can_use_po_with_fused_batch, parallel_over_batch)) {
        const int src_mask
                = utils::get_dims_mask(dst_d.dims(), src_d.dims(), ndims);
        const int wei_mask
                = utils::get_dims_mask(dst_d.dims(), weights_d.dims(), ndims);
        const size_t bia_dt_size = !pd()->with_bias()
                ? 0
                : types::data_type_size(pd()->weights_md(1)->data_type);
        const size_t work_amount = (size_t)batch * M * N;
        const size_t work_per_batch = (size_t)M * N;
        parallel(nthr, [&](int ithr, int nthr) {
            size_t t_work_start {0}, t_work_end {0};
            balance211(work_amount, nthr, ithr, t_work_start, t_work_end);

            dim_t cur_b {0}, cur_m {0}, cur_n {0};
            dims_t s_dims_idx, w_dims_idx, d_dims_idx;
            size_t i_work = t_work_start;
            const bool reuse_acc = acc != (acc_data_t *)dst;
            acc_data_t *curr_acc
                    = reuse_acc ? acc + ithr * acc_stride : nullptr;

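            // i_work is a linear offset into the (batch x M x N) output
            // space. Each iteration decodes it into (cur_b, cur_m, cur_n)
            // and issues one GEMM over the largest contiguous chunk left in
            // this thread's share: a whole matrix, a block of full rows, or
            // a tail within a single row.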
            while (i_work < t_work_end) {
                utils::nd_iterator_init(
                        i_work, cur_b, batch, cur_m, M, cur_n, N);

                utils::l_dims_by_l_offset(
                        d_dims_idx, i_work, dst_d.dims(), ndims);

                utils::copy_dims_with_mask(
                        s_dims_idx, d_dims_idx, batch_ndims, src_mask);
                s_dims_idx[ndims - 2] = cur_m;
                s_dims_idx[ndims - 1] = 0; // k idx is always 0

                utils::copy_dims_with_mask(
                        w_dims_idx, d_dims_idx, batch_ndims, wei_mask);
                w_dims_idx[ndims - 2] = 0; // k idx is always 0
                w_dims_idx[ndims - 1] = cur_n;

                const src_data_t *curr_src = src + src_d.off_v(s_dims_idx);
                const weights_data_t *curr_weights
                        = weights + weights_d.off_v(w_dims_idx);
                const dim_t dst_off = dst_d.off_v(d_dims_idx);
                dst_data_t *curr_dst = dst + dst_off;
                if (!reuse_acc) curr_acc = acc + dst_off;
                dim_t gemm_M {0}, gemm_N {0};

                size_t matrix_offset;
                const size_t rem_work = t_work_end - i_work;
                if (rem_work >= work_per_batch && cur_m == 0 && cur_n == 0) {
                    // parallel over batch
                    gemm_M = M;
                    gemm_N = N;
                    matrix_offset = 0;
                } else if (rem_work >= (size_t)N && cur_n == 0) {
                    // parallel over M
                    gemm_M = nstl::min(
                            (size_t)(M - cur_m), (size_t)(rem_work / N));
                    gemm_N = N;
                    matrix_offset = cur_n + cur_m * N;
                } else {
                    // parallel over N
                    gemm_M = 1;
                    gemm_N = nstl::min((size_t)(N - cur_n), rem_work);
                    matrix_offset = cur_n + cur_m * N;
                }

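                // Row-major C = A * B is computed with a column-major sgemm
                // by swapping operands and transpose flags: column-major
                // B^T * A^T equals C^T, i.e., row-major C.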
                status_t st_thr = extended_sgemm(&transB, &transA, &gemm_N,
                        &gemm_M, &K, &alpha, curr_weights, &ldb, curr_src, &lda,
                        &beta, curr_acc, &acc_ldc, nullptr, false);
                if (st_thr != status::success) {
                    st = st_thr;
                    return;
                }

                if (params.has_pp_kernel_) {
                    const float *pp_scales
                            = params.get_post_processing_scales(scales);
                    const size_t dst_logical_off = i_work;
                    const size_t dim1_off = helper.ndims() > 3
                            ? ((cur_b % batch_without_dim0)
                                    / batch_without_dim01)
                            : cur_m;

                    // offset for case with post-op broadcast_channel
                    const size_t matrix_per_first_batch_off = helper.ndims() > 3
                            ? M * N * (cur_b / batch_without_dim0)
                                    + matrix_offset
                            : 0;
                    const ptrdiff_t oc_off = i_work % N;
                    (*pp_kernel_)(curr_dst, curr_acc,
                            bias + oc_off * bia_dt_size,
                            pp_scales + oc_off * scale_idx_mult, 0,
                            dst_logical_off, dim1_off, gemm_M * gemm_N,
                            static_cast<size_t>(N), ldc, nullptr,
                            post_ops_binary_rhs_arg_vec.data(), dst,
                            matrix_per_first_batch_off, ctx, *pd()->dst_md());
                }
                i_work += gemm_M * gemm_N;
            }
        });
    } else {
        // collapse batch into M, if weights batch dimensions are broadcasted.
        M = batch * M;

        st = extended_sgemm(&transB, &transA, &N, &M, &K, &alpha, weights, &ldb,
                src, &lda, &beta, acc, &acc_ldc, nullptr, false);

        if (st == status::success && params.has_pp_kernel_) {
            const bool force_sequential = pp_kernel_->sequential_kernel();
            const float *pp_scales = params.get_post_processing_scales(scales);
            parallel(force_sequential ? 1 : nthr, [&](int ithr, int nthr) {
                size_t start {}, end {};
                balance211((size_t)(M * N), nthr, ithr, start, end);
                const size_t dst_logical_off = start;
                const size_t dst_start_row_idx = start % N;
                (*pp_kernel_)(dst, acc, bias, pp_scales, start, dst_logical_off,
                        dst_start_row_idx, end, (size_t)N, ldc, nullptr,
                        post_ops_binary_rhs_arg_vec.data(), dst, 0, ctx,
                        *pd()->dst_md());
            });
        }
    }

    if (need_free_acc) free(acc);

    return st;
}

} // namespace matmul
} // namespace cpu
} // namespace impl
} // namespace dnnl