/*******************************************************************************
* Copyright 2016-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_REORDER_SIMPLE_REORDER_HPP
#define CPU_REORDER_SIMPLE_REORDER_HPP

#include <assert.h>

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/primitive.hpp"
#include "common/primitive_attr.hpp"
#include "common/tag_traits.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/reorder/cpu_reorder_pd.hpp"

#include "cpu/simple_q10n.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using bd = block_dim_t;
using ib = inner_blk_t;

template <impl::data_type_t type>
using data_t = typename prec_traits<type>::type;

template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;

template <impl::data_type_t type_i, impl::data_type_t type_o>
using _qz = qz<data_t<type_i>, data_t<type_o>>;

namespace fmt_order {
const bool keep = true;
const bool reverse = false;
const bool any = keep;
} // namespace fmt_order

namespace spec {
struct direct_copy {};
struct direct_copy_except_dim_0 {};
struct reference {};
struct conv_req_comp {}; // {s8, u8: asymmetric quantization}
} // namespace spec
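
// Illustration only (a hypothetical instantiation; real dispatch goes through
// the reorder implementation lists): a conv-weights reorder from plain f32 to
// s8 "hwio" with compensation would be matched as
//     simple_reorder_impl<data_type::f32, format_tag::any, data_type::s8,
//             format_tag::hwio, /*order_keep=*/true, spec::conv_req_comp>
// where is_applicable() gates the match at pd-creation time and execute()
// performs the actual copy plus quantization.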

#define SIMPLE_REORDER_TEMPL_DECL \
    impl::data_type_t type_i, impl::format_tag_t tag_i, \
            impl::data_type_t type_o, impl::format_tag_t tag_o, \
            bool order_keep
#define SIMPLE_REORDER_TEMPL_CALL type_i, tag_i, type_o, tag_o, order_keep

#define DECLARE_COMMON_PARAMS() \
    auto input = CTX_IN_MEM(const data_t<type_i> *, DNNL_ARG_FROM); \
    auto output = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO); \
    const auto &scratchpad = ctx.get_scratchpad_grantor(); \
    MAYBE_UNUSED(scratchpad); \
    const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd->src_md()); \
    const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd->dst_md()); \
    const float alpha = pd->alpha(); \
    MAYBE_UNUSED(alpha); \
    const float beta = pd->beta(); \
    MAYBE_UNUSED(beta);

#define GET_SCRATCHPAD_SIZE_ZERO() \
    static size_t get_scratchpad_size(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d) { \
        return 0; \
    }
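
// DECLARE_COMMON_PARAMS() pulls the source/destination pointers, the memory
// descriptor wrappers, the scratchpad grantor, and alpha/beta out of the
// execution context for every execute() body below; GET_SCRATCHPAD_SIZE_ZERO()
// stamps out the get_scratchpad_size() hook for implementations that need no
// temporary storage (the f32 -> bf16 reorders below define their own instead).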

/* specific reorders: common template */
template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_impl {};

namespace {
inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i,
        impl::format_tag_t tag_o, const memory_desc_wrapper &input_d,
        const memory_desc_wrapper &output_d) {
    if (input_d.has_runtime_dims_or_strides()) return false;
    return input_d.matches_tag(order_keep ? tag_i : tag_o)
            && output_d.matches_tag(order_keep ? tag_o : tag_i);
}
inline bool simple_po_check(const primitive_attr_t *attr) {
    const auto &po = attr->post_ops_;
    return po.len() == 0
            || (po.len() == 1 && po.contain(primitive_kind::sum, 0));
}
inline bool simple_attr_check(const primitive_attr_t *attr,
        bool many_scales_support, bool sum_support) {
    using smask_t = primitive_attr_t::skip_mask_t;
    smask_t skip_mask = smask_t::oscale;
    if (sum_support) skip_mask = skip_mask | smask_t::post_ops;
    if (!attr->has_default_values(skip_mask)) return false;
    if (!attr->defined()) return false;
    // Reject unsupported post-ops instead of discarding the check result.
    if (sum_support && !simple_po_check(attr)) return false;
    if (many_scales_support) return true;
    return attr->output_scales_.mask_ == 0;
}
} // namespace
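
// A recurring idiom below: assuming the scales mask covers a dense prefix of
// the logical dims (e.g. 0x1 for (oc), 0x3 for (g, oc)),
// math::ilog2q(mask + 1) is the number of masked dimensions, so
//     D_mask = utils::array_product(dims, math::ilog2q(mask + 1))
// is the number of output scales the mask implies: 1 for a single common
// scale, g * oc for per-channel scales on grouped weights.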

/* specific reorders: implementation */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && utils::one_of(tag_o, format_tag::wio,
                                format_tag::wigo, format_tag::hwio,
                                format_tag::hwigo, format_tag::dhwio,
                                format_tag::dhwigo),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const size_t D_mask = array_product(
                input_d.dims(), math::ilog2q(attr->output_scales_.mask_ + 1));
        static constexpr bool w_groups = one_of(
                tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);
        const int oc_idx = w_groups ? 1 : 0;
        const int oc = input_d.dims()[oc_idx];
        const int g = w_groups ? (input_d.dims()[0]) : 1;

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
        };

        return simple_attr_check(attr, true, false)
                && output_d.matches_tag(tag_o) && input_d.is_plain()
                && (req_comp || req_asymmetric_comp)
                && mask_ok(req_comp, output_d.extra().compensation_mask)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && IMPLICATION(
                        req_comp, one_of(D_mask, (size_t)1, (size_t)g * oc))
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        static constexpr bool w_groups = utils::one_of(
                tag_o, format_tag::wigo, format_tag::hwigo, format_tag::dhwigo);
        static constexpr bool w_height
                = !utils::one_of(tag_o, format_tag::wio, format_tag::wigo);
        static constexpr bool w_depth
                = utils::one_of(tag_o, format_tag::dhwio, format_tag::dhwigo);

        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int IC = dims[w_groups + 1];
        const int D = w_depth ? dims[w_groups + 2] : 1;
        const int H = w_height ? dims[w_groups + w_depth + 2] : 1;
        const int W = dims[w_groups + w_depth + w_height + 2];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

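        // The compensation values live in a hidden buffer appended to the
        // destination: output_d.size() includes that extra space, so the
        // payload ends at size() - additional_buffer_size(). When both kinds
        // of compensation are requested, the s8s8 block (G * padded_OC int32
        // values) comes first and the asymmetric-src block follows it.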
        size_t offset = output_d.size() - output_d.additional_buffer_size();
        size_t zp_offset = offset
                + (req_comp ? G * pdims[w_groups + 0] * sizeof(int32_t) : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

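        // Quantize each (g, oc) channel independently, accumulating the
        // negated sum of the produced s8 weights along the way; for the s8s8
        // case the sum is scaled by 128 at the end, i.e.
        //     cp[g * OC + oc] = -128 * sum over (ic, d, h, w) of w_s8,
        // the correction term the int8 convolution kernels consume to undo
        // the s8 -> u8 shift applied to source activations.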
        parallel_nd(G, OC, [&](int g, int oc) {
            if (req_comp) cp[g * OC + oc] = 0;
            if (has_asymmetric_comp) zp[g * OC + oc] = 0;
            for_(int ic = 0; ic < IC; ic++)
            for_(int d = 0; d < D; d++)
            for_(int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                auto i = w_depth
                        ? input[input_d.blk_off<!w_groups>(g, oc, ic, d, h, w)]
                        : w_height ? input[input_d.blk_off<!w_groups>(
                                  g, oc, ic, h, w)]
                                   : input[input_d.blk_off<!w_groups>(
                                           g, oc, ic, w)];
                auto &o = w_depth
                        ? output[output_d.blk_off<!w_groups>(
                                g, oc, ic, d, h, w)]
                        : w_height ? output[output_d.blk_off<!w_groups>(
                                  g, oc, ic, h, w)]
                                   : output[output_d.blk_off<!w_groups>(
                                           g, oc, ic, w)];
                const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];

                o = qz_b0<data_t<type_i>, data_t<type_o>>()(i, s * adj_scale);
                if (req_comp) cp[g * OC + oc] -= (int32_t)o;
                if (has_asymmetric_comp) zp[g * OC + oc] -= (int32_t)o;
            }
            if (req_comp) cp[g * OC + oc] *= 128;
        });
        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<
                (utils::one_of(tag_i, format_tag::iwo, format_tag::oiw,
                         format_tag::wio)
                        && utils::one_of(tag_o, format_tag::OIw4i16o4i,
                                format_tag::OIw4i32o4i, format_tag::OIw4i64o4i,
                                format_tag::OIw2i8o4i, format_tag::OIw4o4i))
                        || (utils::one_of(tag_i, format_tag::oi, format_tag::io)
                                && utils::one_of(tag_o, format_tag::OI4i16o4i,
                                        format_tag::OI4i32o4i,
                                        format_tag::OI4i64o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::gOIw4i16o4i,
                                        format_tag::gOIw2i8o4i,
                                        format_tag::gOIw4o4i))
                        || (utils::one_of(tag_i, format_tag::ihwo,
                                    format_tag::hwio, format_tag::oihw)
                                && utils::one_of(tag_o, format_tag::OIhw4i16o4i,
                                        format_tag::OIhw4i32o4i,
                                        format_tag::OIhw4i64o4i,
                                        format_tag::OIhw2i8o4i,
                                        format_tag::OIhw4o4i))
                        || (utils::one_of(tag_i, format_tag::idhwo,
                                    format_tag::dhwio, format_tag::oidhw)
                                && utils::one_of(tag_o,
                                        format_tag::OIdhw4i16o4i,
                                        format_tag::OIdhw4i32o4i,
                                        format_tag::OIdhw4i64o4i,
                                        format_tag::OIdhw2i8o4i,
                                        format_tag::OIdhw4o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::gOIhw4o4i,
                                        format_tag::gOIhw2i8o4i,
                                        format_tag::gOIhw4i16o4i))
                        || (utils::one_of(tag_i, format_tag::goidhw)
                                && (utils::one_of(tag_o,
                                        format_tag::gOIdhw4i16o4i,
                                        format_tag::gOIdhw2i8o4i,
                                        format_tag::gOIdhw4o4i))),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const size_t D_mask = array_product(
                input_d.dims(), math::ilog2q(attr->output_scales_.mask_ + 1));
        const bool w_groups = !one_of(tag_o, OIw4i16o4i, OIw2i8o4i, OIw4o4i,
                OIhw4i16o4i, OIhw2i8o4i, OIhw4o4i, OIdhw4i16o4i, OIdhw2i8o4i,
                OIdhw4o4i, OI4i16o4i, OI4i32o4i, OI4i64o4i, OIw4i32o4i,
                OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i, OIdhw4i32o4i,
                OIdhw4i64o4i);
        const int oc = (input_d.dims()[w_groups ? 1 : 0]);
        const int g = w_groups ? input_d.dims()[0] : 1;

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(check, mask == (w_groups ? 0x3 : 0x1));
        };

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && (req_comp || req_asymmetric_comp)
                && mask_ok(req_comp, output_d.extra().compensation_mask)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && IMPLICATION(
                        req_comp, one_of(D_mask, (size_t)1, (size_t)g * oc))
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = !utils::one_of(tag_o, OIw4o4i,
                OIw4i16o4i, OIhw4i16o4i, OIdhw4i16o4i, OIhw4o4i, OIw2i8o4i,
                OIhw2i8o4i, OIdhw2i8o4i, OIdhw4o4i, OI4i16o4i, OI4i32o4i,
                OI4i64o4i, OIw4i32o4i, OIw4i64o4i, OIhw4i32o4i, OIhw4i64o4i,
                OIdhw4i32o4i, OIdhw4i64o4i);

        constexpr int is_0d
                = utils::one_of(tag_o, OI4i16o4i, OI4i32o4i, OI4i64o4i);
        constexpr int is_1d
                = utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i, gOIw2i8o4i,
                        OIw2i8o4i, gOIw4o4i, OIw4o4i, OIw4i32o4i, OIw4i64o4i);
        constexpr int is_3d = utils::one_of(tag_o, gOIdhw4i16o4i, OIdhw4i16o4i,
                gOIdhw2i8o4i, OIdhw2i8o4i, gOIdhw4o4i, OIdhw4o4i, OIdhw4i32o4i,
                OIdhw4i64o4i);
        constexpr int icblksize = utils::one_of(tag_traits<tag_o>::inner_blks,
                                          ib::_4a4b, ib::_4b4c)
                ? 4
                : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_2c8b4c,
                          ib::_2b8a4b)
                        ? 8
                        : 16;
        constexpr int ocblksize = tag_traits<tag_o>::inner_blks == ib::_4b32a4b
                ? 32
                : tag_traits<tag_o>::inner_blks == ib::_4b64a4b ? 64
                                                                : icblksize;
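
        // Effective block sizes recovered from the tag's inner blocks: the
        // compound ic blocking of e.g. 4i16o4i yields an ic block of 16,
        // 2i8o4i yields 8, plain 4o4i yields 4; the oc block matches the ic
        // block except for 4i32o4i / 4i64o4i, where it widens to 32 / 64.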

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int NB_OC = pdims[w_groups + 0] / ocblksize;
        const int IC = dims[w_groups + 1];
        const int NB_IC = pdims[w_groups + 1] / icblksize;
        const int D = is_3d ? dims[2 + w_groups] : 1;
        const int H = is_1d || is_0d ? 1 : dims[2 + w_groups + is_3d];
        const int W = is_0d ? 1 : dims[w_groups + is_3d + 3 - is_1d];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;
        const bool broadcast_scales = (D_mask == 1);

        // This kernel is used primarily for tensors with multiple inner
        // blocks for which generic zero padding must be used.
        // TODO: apply zero padding inside parallel_nd()
        ctx.zero_pad_output(DNNL_ARG_TO);

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *c, int32_t *zp, const float *s,
                           const int oc_block, const int ic_block) {
#define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
            for_(int ic = 0; ic < ic_block; ++ic)
            for (int oc = 0; oc < oc_block; ++oc) {
                const auto plain_off
                        = oc * plain_d.blocking_desc().strides[w_groups + 0]
                        + ic * plain_d.blocking_desc().strides[w_groups + 1];
                out[index(oc, ic)] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[plain_off],
                        s[broadcast_scales ? 0 : oc] * adj_scale);
                if (req_comp) c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
                if (has_asymmetric_comp)
                    zp[oc] -= (int32_t)(out[index(oc, ic)]);
            }
#undef index
        };

        constexpr int i_mult_ic = icblksize;
        constexpr int i_mult_oc = ocblksize;
        constexpr int o_mult = 1;

        size_t offset
                = G * pdims[w_groups + 0] * pdims[w_groups + 1] * D * H * W;
        size_t zp_offset = offset
                + (req_comp ? G * pdims[w_groups + 0] * sizeof(int32_t) : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;

        parallel_nd(G * OC, [&](dim_t i) {
            if (req_comp) cp[i] = 0;
            if (has_asymmetric_comp) zp[i] = 0;
        });

#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_0d ? (md).blk_off<!w_groups>(g, o, i) \
           : is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
                   : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                           : (md).blk_off<!w_groups>(g, o, i, h, w))
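        // wei_blk_off() dispatches on the spatial rank encoded in the tag so
        // that the single loop nest below serves the 0d (inner-product), 1d,
        // 2d and 3d weight layouts alike.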
        parallel_nd(G, NB_OC, [&](int g, int O) {
            for_(int I = 0; I < NB_IC; I++)
            for_(int d = 0; d < D; d++)
            for_(int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                auto i = &input[wei_blk_off(
                        input_d, g, i_mult_oc * O, i_mult_ic * I, d, h, w)];
                auto o = &output[wei_blk_off(
                        output_d, g, o_mult * O, o_mult * I, d, h, w)];
                const int oc_block = nstl::min(ocblksize, OC - O * ocblksize);
                const int ic_block = nstl::min(icblksize, IC - I * icblksize);
                int _offset = (g * NB_OC + O) * ocblksize;
                ker(i, o, (order_keep && req_comp) ? &cp[_offset] : nullptr,
                        (order_keep && has_asymmetric_comp) ? &zp[_offset]
                                                            : nullptr,
                        &scales[broadcast_scales ? 0 : _offset], oc_block,
                        ic_block);
            }
        });

#undef wei_blk_off

        return status::success;
    }
};

/* Asymmetric Blocking */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo,
                                           format_tag::oiw, format_tag::wio)
                                          && utils::one_of(
                                                  tag_o, format_tag::Owi16o))
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::gOwi16o))
                        || (utils::one_of(tag_i, format_tag::ihwo,
                                    format_tag::hwio, format_tag::oihw)
                                && utils::one_of(tag_o, format_tag::Owhi16o))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::gOwhi16o)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const bool w_groups = !one_of(tag_o, Owi16o, Owhi16o);

        // Current formats are only used in jit kernels that natively
        // support s8 instructions, hence, there is no need for signed
        // compensation.
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;

        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            const int c_mask = 0x1,
                      g_mask = 0x3; // mask for i/o-channel and ngroups
            return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask));
        };

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8 && !req_comp;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = !utils::one_of(tag_o, Owi16o, Owhi16o);
        constexpr int is_1d = utils::one_of(tag_o, Owi16o, gOwi16o);
        const bool is_3d = false; // TODO once enabled

        constexpr int oc_blksize = 16;

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int NB_OC = pdims[w_groups + 0] / oc_blksize;
        const int IC = dims[w_groups + 1];

        const int D = is_3d ? dims[2 + w_groups] : 1;
        const int H = is_1d ? 1 : dims[2 + w_groups + is_3d];
        const int W = dims[w_groups + is_3d + 3 - is_1d];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *zp, const float *s, const int oc_block) {
            for (int oc = 0; oc < oc_block; ++oc) {
                const auto plain_off
                        = oc * plain_d.blocking_desc().strides[w_groups + 0];
                out[oc] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[plain_off], s[oc] * adj_scale);
                if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[oc]);
            }
            // fill memory with '0' in case of padded channel dimensions
            for (int oc = oc_block; oc < oc_blksize; ++oc) {
                out[oc] = 0;
            }
        };
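
        // Unlike the previous kernel, padded output channels are zero-filled
        // inline by ker() (its tail loop), so this implementation does not
        // need a separate ctx.zero_pad_output() pass.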

        size_t offset
                = G * pdims[w_groups + 0] * pdims[w_groups + 1] * D * H * W;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + offset)
                : nullptr;

        if (has_asymmetric_comp) {
            parallel_nd(G * NB_OC * oc_blksize, [&](int i) { zp[i] = 0; });
        }

#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
           : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                   : (md).blk_off<!w_groups>(g, o, i, h, w))

        parallel_nd(G, NB_OC, [&](int g, int O) {
            for_(int I = 0; I < IC; I++)
            for_(int d = 0; d < D; d++)
            for_(int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                auto i = &input[wei_blk_off(
                        input_d, g, oc_blksize * O, I, d, h, w)];
                auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
                const int oc_block = nstl::min(oc_blksize, OC - O * oc_blksize);
                int _offset = (g * NB_OC + O) * oc_blksize;
                ker(i, o,
                        (order_keep && has_asymmetric_comp) ? &zp[_offset]
                                                            : nullptr,
                        &scales[(D_mask == 1) ? 0 : _offset], oc_block);
            }
        });

#undef wei_blk_off

        return status::success;
    }
};

/* Asymmetric Blocking */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(utils::one_of(tag_i, format_tag::iwo,
                                           format_tag::oiw, format_tag::wio)
                                          && utils::one_of(tag_o,
                                                  format_tag::OwI16o4i,
                                                  format_tag::OIw16i16o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::gOwI16o4i,
                                        format_tag::gOIw16i16o4i))
                        || (utils::one_of(tag_i, format_tag::ihwo,
                                    format_tag::hwio, format_tag::oihw)
                                && utils::one_of(tag_o, format_tag::OhwI16o4i,
                                        format_tag::OIhw16i16o4i))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::gOhwI16o4i,
                                        format_tag::gOIhw16i16o4i))
                        || (utils::one_of(tag_i, format_tag::idhwo,
                                    format_tag::dhwio, format_tag::oidhw)
                                && utils::one_of(tag_o, format_tag::OdhwI16o4i,
                                        format_tag::OIdhw16i16o4i))
                        || (utils::one_of(tag_i, format_tag::goidhw)
                                && utils::one_of(tag_o, format_tag::gOdhwI16o4i,
                                        format_tag::gOIdhw16i16o4i)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const bool w_groups = !one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i,
                OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i);

        // Current formats are only used in jit kernels that natively
        // support s8 instructions, hence, there is no need for signed
        // compensation.
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;

        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            const int c_mask = 0x1,
                      g_mask = 0x3; // mask for i/o-channel and ngroups
            return IMPLICATION(check, mask == (w_groups ? g_mask : c_mask));
        };

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8 && !req_comp;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups
                = !utils::one_of(tag_o, OwI16o4i, OIw16i16o4i, OhwI16o4i,
                        OIhw16i16o4i, OdhwI16o4i, OIdhw16i16o4i);
        constexpr int is_1d = utils::one_of(
                tag_o, OwI16o4i, gOwI16o4i, OIw16i16o4i, gOIw16i16o4i);
        const bool is_3d = utils::one_of(
                tag_o, OdhwI16o4i, gOdhwI16o4i, OIdhw16i16o4i, gOIdhw16i16o4i);

        constexpr int oc_blksize = 16;
        constexpr int ic_blksize = utils::one_of(tag_traits<tag_o>::inner_blks,
                                           ib::_16b16a4b, ib::_16c16b4c)
                ? 64
                : utils::one_of(
                          tag_traits<tag_o>::inner_blks, ib::_16a4b, ib::_16b4c)
                        ? 4
                        : 1;
        assert(ic_blksize != 1);
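
        // ic_blksize folds both levels of input-channel blocking into one
        // step: the 16i...4i tags consume 64 input channels per block, the
        // ...I16o4i tags consume 4; the assert guards against a tag slipping
        // through the enable_if above without a known inner-block pattern.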

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int NB_OC = pdims[w_groups + 0] / oc_blksize;
        const int IC = dims[w_groups + 1];
        const int NB_IC = pdims[w_groups + 1] / ic_blksize;

        const int D = is_3d ? dims[2 + w_groups] : 1;
        const int H = is_1d ? 1 : dims[2 + w_groups + is_3d];
        const int W = dims[w_groups + is_3d + 3 - is_1d];

        const float *scales = pd->attr()->output_scales_.scales_;
        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        // This kernel is used primarily for tensors with multiple inner
        // blocks for which generic zero padding must be used.
        // TODO: apply zero padding inside parallel_nd()
        ctx.zero_pad_output(DNNL_ARG_TO);

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *zp, const float *s, const int oc_block,
                           const int ic_block) {
            for_(int ic = 0; ic < ic_block; ++ic)
            for (int oc = 0; oc < oc_block; ++oc) {
                const auto plain_off
                        = oc * plain_d.blocking_desc().strides[w_groups + 0]
                        + ic * plain_d.blocking_desc().strides[w_groups + 1];
                auto index = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                        oc, ic);
                out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[plain_off], s[oc] * adj_scale);

                if (has_asymmetric_comp) zp[oc] -= (int32_t)(out[index]);
            }
        };

        size_t offset
                = G * pdims[w_groups + 0] * pdims[w_groups + 1] * D * H * W;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + offset)
                : nullptr;

        if (has_asymmetric_comp) {
            parallel_nd(G * NB_OC * oc_blksize, [&](int i) { zp[i] = 0; });
        }

#define wei_blk_off(md, g, o, i, d, h, w) \
    (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
           : is_3d ? (md).blk_off<!w_groups>(g, o, i, d, h, w) \
                   : (md).blk_off<!w_groups>(g, o, i, h, w))

        parallel_nd(G, NB_OC, [&](int g, int O) {
            for_(int I = 0; I < NB_IC; I++)
            for_(int d = 0; d < D; d++)
            for_(int h = 0; h < H; h++)
            for (int w = 0; w < W; w++) {
                auto i = &input[wei_blk_off(
                        input_d, g, oc_blksize * O, ic_blksize * I, d, h, w)];
                auto o = &output[wei_blk_off(output_d, g, O, I, d, h, w)];
                const int oc_block = nstl::min(oc_blksize, OC - O * oc_blksize);
                const int ic_block = nstl::min(ic_blksize, IC - I * ic_blksize);
                int _offset = (g * NB_OC + O) * oc_blksize;
                ker(i, o,
                        (order_keep && has_asymmetric_comp) ? &zp[_offset]
                                                            : nullptr,
                        &scales[(D_mask == 1) ? 0 : _offset], oc_block,
                        ic_block);
            }
        });

#undef wei_blk_off

        return status::success;
    }
};

/* Asymmetric Blocking */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<
                (utils::one_of(tag_i, format_tag::ab, format_tag::ba)
                        && utils::one_of(tag_o, format_tag::BA16a16b4a,
                                format_tag::BA16a32b4a, format_tag::BA16a48b4a,
                                format_tag::BA16a64b4a)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace format_tag;
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        // Current formats are only used in jit kernels that natively
        // support s8 instructions, hence, there is no need for signed
        // compensation.
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;

        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        auto mask_ok = [&](bool check, int mask) {
            return IMPLICATION(check, mask == 1 << 1);
        };

        const size_t D_mask = utils::array_product(
                input_d.dims(), math::ilog2q(attr->output_scales_.mask_ + 1));

        return simple_attr_check(attr, true, false)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && mask_ok(req_asymmetric_comp,
                        output_d.extra().asymm_compensation_mask)
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8 && !req_comp && D_mask == 1;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        constexpr int A_blksize = 64;
        constexpr int B_blksize
                = (tag_traits<tag_o>::inner_blks == ib::_16a64b4a)
                ? 64
                : (tag_traits<tag_o>::inner_blks == ib::_16a48b4a)
                        ? 48
                        : (tag_traits<tag_o>::inner_blks == ib::_16a32b4a)
                                ? 32
                                : (tag_traits<tag_o>::inner_blks
                                          == ib::_16a16b4a)
                                        ? 16
                                        : 1;
        assert(B_blksize != 1);

        const auto &plain_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const int Adim = dims[0];
        const int NB_Adim = pdims[0] / A_blksize;
        const int Bdim = dims[1];
        const int NB_Bdim = pdims[1] / B_blksize;

        const float *scales = pd->attr()->output_scales_.scales_;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                           int32_t *zp, const float *s, const int a_block,
                           const int b_block) {
            for (int a = 0; a < a_block; ++a) {
                for (int b = 0; b < b_block; ++b) {
                    const auto plain_off
                            = a * plain_d.blocking_desc().strides[0]
                            + b * plain_d.blocking_desc().strides[1];
                    auto index
                            = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                                    a, b);
                    out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            inp[plain_off], s[0] * adj_scale);

                    if (has_asymmetric_comp) zp[b] -= (int32_t)(out[index]);
                }
                for (int b = b_block; b < B_blksize; ++b) {
                    auto index
                            = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(
                                    a, b);
                    out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            0, s[0] * adj_scale);
                }
            }

            for_(int a = a_block; a < A_blksize; ++a)
            for (int b = 0; b < B_blksize; ++b) {
                auto index
                        = AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>(a, b);
                out[index] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        0, s[0] * adj_scale);
            }
        };
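
        // ker() always writes a full A_blksize x B_blksize block: tails in
        // both dimensions are filled through the same quantization call with
        // a zero input, so no separate zero-padding pass is required.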

        size_t offset = pdims[0] * pdims[1];
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + offset)
                : nullptr;

        if (has_asymmetric_comp) {
            parallel_nd(NB_Bdim * B_blksize, [&](int i) { zp[i] = 0; });
        }

#define get_blk_off(md, a, b) (md).blk_off(a, b)

        parallel_nd(NB_Bdim, [&](int B) {
            for (int A = 0; A < NB_Adim; A++) {
                auto i = &input[get_blk_off(
                        input_d, A_blksize * A, B_blksize * B)];
                auto o = &output[get_blk_off(output_d, A, B)];
                const int a_block = nstl::min(A_blksize, Adim - A * A_blksize);
                const int b_block = nstl::min(B_blksize, Bdim - B * B_blksize);
                int _offset = B * B_blksize;
                ker(i, o,
                        (order_keep && has_asymmetric_comp) ? &zp[_offset]
                                                            : nullptr,
                        &scales[0], a_block, b_block);
            }
        });

#undef get_blk_off

        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<false
                        || (utils::one_of(
                                    tag_i, format_tag::goiw, format_tag::wigo)
                                && utils::one_of(tag_o, format_tag::Goiw16g,
                                        format_tag::Goiw8g, format_tag::Goiw4g))
                        || (utils::one_of(
                                    tag_i, format_tag::goihw, format_tag::hwigo)
                                && utils::one_of(tag_o, format_tag::Goihw16g,
                                        format_tag::Goihw8g,
                                        format_tag::Goihw4g)),
                spec::conv_req_comp>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;
        using namespace utils;

        if (input_d.has_runtime_dims_or_strides()) return false;

        const size_t D_mask = array_product(
                input_d.dims(), math::ilog2q(attr->output_scales_.mask_ + 1));
        const dim_t g = input_d.dims()[0];
        const dim_t oc = input_d.dims()[1];
        const dim_t ic = input_d.dims()[2];

        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool req_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        return order_keep && oc == 1 && ic == 1 // depth-wise case
                && simple_attr_check(attr, true, false)
                && (req_comp || req_asymmetric_comp)
                && input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && IMPLICATION(
                        req_comp, one_of(D_mask, (size_t)1, (size_t)g * oc))
                && one_of(input_d.data_type(), f32, s8, bf16)
                && output_d.data_type() == s8;
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        constexpr bool is_1d
                = utils::one_of(tag_i, format_tag::goiw, format_tag::wigo);
        constexpr int blksize
                = utils::one_of(tag_o, format_tag::Goihw4g, format_tag::Goiw4g)
                ? 4
                : utils::one_of(tag_o, format_tag::Goihw8g, format_tag::Goiw8g)
                        ? 8
                        : 16;

        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();
        const int G = dims[0];
        const int Gp = pdims[0];
        const int OC = dims[1];
        const int IC = dims[2];
        const int H = is_1d ? 1 : dims[3];
        const int W = dims[4 - is_1d];
        const auto zero_padding_needed = !output_d.is_dense();

        const size_t D_mask = utils::array_product(input_d.dims(),
                math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
        const float *scales = pd->attr()->output_scales_.scales_;
        const bool req_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_s8s8;
        const bool has_asymmetric_comp = output_d.extra().flags
                & memory_extra_flags::compensation_conv_asymmetric_src;

        assert(req_comp || has_asymmetric_comp);

        float adj_scale
                = (output_d.extra().flags & memory_extra_flags::scale_adjust)
                ? output_d.extra().scale_adjust
                : 1.f;

        auto ker_out = [&](const data_t<type_i> *inp, data_t<type_o> *out,
                               const float *s, const int g_block) {
            PRAGMA_OMP_SIMD()
            for (int g = 0; g < g_block; g++) {
                const auto i_off = g * input_d.blocking_desc().strides[0];
                out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                        inp[i_off], s[g * OC] * adj_scale);
            }
        };

        /* Note: having separate kernels for s8 and zero-point fixes a
         * compiler-generated bug which results in seg-fault. */
        auto ker_s8 = [&](const data_t<type_o> *out, int32_t *cp,
                              const int g_block) {
            PRAGMA_OMP_SIMD()
            for (int g = 0; g < g_block; g++) {
                cp[g * OC] -= 128 * (int32_t)(out[g]);
            }
        };
        auto ker_zp = [&](const data_t<type_o> *out, int32_t *zp,
                              const int g_block) {
            PRAGMA_OMP_SIMD()
            for (int g = 0; g < g_block; g++) {
                zp[g * OC] -= (int32_t)(out[g]);
            }
        };

        size_t offset = output_d.size() - output_d.additional_buffer_size();
        size_t zp_offset = offset + (req_comp ? Gp * OC * sizeof(int32_t) : 0);
        int32_t *cp = req_comp ? reinterpret_cast<int32_t *>(output + offset)
                               : nullptr;
        int32_t *zp = has_asymmetric_comp
                ? reinterpret_cast<int32_t *>(output + zp_offset)
                : nullptr;
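
        // Since is_applicable() enforces oc == ic == 1, OC == 1 here and the
        // compensation buffers are effectively indexed per group: each kernel
        // call below updates g_block consecutive entries starting at
        // gb * blksize + O.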

        parallel_nd((Gp / blksize) * OC, [&](int ib) {
            PRAGMA_OMP_SIMD()
            for (int i = 0; i < blksize; i++) {
                if (req_comp) cp[ib * blksize + i] = 0;
                if (has_asymmetric_comp) zp[ib * blksize + i] = 0;
            }
        });

#define wei_blk_off(md, g, o, i, h, w) \
    (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))

        parallel_nd(Gp / blksize, OC, [&](int gb, int O) {
            for (int I = 0; I < IC; I++) {
                for_(int h = 0; h < H; h++)
                for (int w = 0; w < W; w++) {
                    const int g_block = nstl::min(G - gb * blksize, blksize);
                    const auto inp = &input[wei_blk_off(
                            input_d, gb * blksize, O, I, h, w)];
                    const auto out
                            = &output[wei_blk_off(output_d, gb, O, I, h, w)];
                    int offset = gb * blksize + O;

                    ker_out(inp, out, &scales[(D_mask == 1) ? 0 : offset],
                            g_block);
                    if (req_comp) ker_s8(out, &cp[offset], g_block);
                    if (has_asymmetric_comp) ker_zp(out, &zp[offset], g_block);

                    if (zero_padding_needed) {
                        PRAGMA_OMP_SIMD()
                        for (int off = g_block; off < blksize; off++)
                            out[off] = 0;
                    }
                }
            }
        });

#undef wei_blk_off

        return status::success;
    }
};

/* bf16 reorders */
template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(
                (tag_i == format_tag::goihw || tag_i == format_tag::oihw)
                && (tag_o == format_tag::gOIhw16i16o
                        || tag_o == format_tag::OIhw16i16o
                        || tag_o == format_tag::gOIhw8i16o2i
                        || tag_o == format_tag::OIhw8i16o2i
                        || tag_o == format_tag::gOIhw8o16i2o
                        || tag_o == format_tag::OIhw8o16i2o
                        || tag_o == format_tag::gIOhw8o16i2o
                        || tag_o == format_tag::IOhw8o16i2o)
                && type_i == data_type::f32
                && type_o == data_type::bf16)>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;

        if (input_d.has_runtime_dims_or_strides()) return false;

        return order_keep && input_d.matches_tag(tag_i)
                && output_d.matches_tag(tag_o) && input_d.data_type() == f32
                && output_d.data_type() == bf16 && attr->has_default_values();
    }

    static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d) {
        const int blksize = 16;
        return sizeof(float) * blksize * blksize * dnnl_get_max_threads();
    }
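
    // Each thread stages one 16x16 (ic, oc) tile as f32 in the scratchpad and
    // then converts the whole tile to bf16 with a single
    // cvt_float_to_bfloat16() call over the contiguous buffer.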

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        static constexpr bool w_groups = tag_i == goihw;
        const int blksize = 16;
        const int sblk = 2;

        const auto &plain_d = input_d;
        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();

        const int G = w_groups ? dims[0] : 1;
        const int OC = dims[w_groups + 0];
        const int NB_OC = pdims[w_groups + 0] / blksize;
        const int IC = dims[w_groups + 1];
        const int NB_IC = pdims[w_groups + 1] / blksize;
        const int H = dims[w_groups + 2];
        const int W = dims[w_groups + 3];

        const size_t wsp_size = blksize * blksize;
        float *wspace = scratchpad.template get<float>(
                memory_tracking::names::key_reorder_space);

        auto index = [&](const int ic, const int oc) {
            if (utils::one_of(tag_o, gOIhw16i16o, OIhw16i16o))
                return (ic * blksize + oc);
            else if (utils::one_of(tag_o, gOIhw8i16o2i, OIhw8i16o2i))
                return ((ic / sblk) * blksize * sblk + sblk * oc + ic % sblk);
            else if (utils::one_of(tag_o, gOIhw8o16i2o, gIOhw8o16i2o,
                             OIhw8o16i2o, IOhw8o16i2o))
                return ((oc / sblk) * blksize * sblk + sblk * ic + oc % sblk);
            else
                assert(!"Invalid weight format");
            return 0;
        };
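
        // index() linearizes (ic, oc) within one 16x16 block for the three
        // supported layouts: ic-major for 16i16o, pairs of input channels
        // interleaved behind each output channel for 8i16o2i, and the
        // transposed variant (pairs of output channels) for 8o16i2o.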
1146 
        auto ker = [&](const data_t<type_i> *inp, data_t<type_i> *out,
                           const int curr_oc_block, const int oc_block,
                           const int curr_ic_block, const int ic_block) {
            int ic = 0;
            for (ic = 0; ic < curr_ic_block; ++ic) {
                int oc = 0;
                for (oc = 0; oc < curr_oc_block; ++oc) {
                    const auto plain_off
                            = oc * plain_d.blocking_desc().strides[w_groups + 0]
                            + ic
                                    * plain_d.blocking_desc()
                                              .strides[w_groups + 1];
                    out[index(ic, oc)] = inp[plain_off];
                }
                for (/* continue */; oc < oc_block; ++oc) {
                    out[index(ic, oc)] = (data_t<type_i>)0;
                }
            }
            for (/* continue */; ic < ic_block; ++ic) {
                for (int oc = 0; oc < oc_block; ++oc) {
                    out[index(ic, oc)] = (data_t<type_i>)0;
                }
            }
        };

        constexpr int i_mult = blksize;
        constexpr int o_mult = 1;

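        // One iteration per (g, oc-block, ic-block, h, w): stage the f32
        // tile in the per-thread workspace, then convert the whole tile to
        // bf16 at once.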
        parallel_nd_ext(0, G, NB_OC, NB_IC, H, W,
                [&](int ithr, int, int g, int O, int I, int h, int w) {
                    float *_wspace = wspace + wsp_size * ithr;
                    auto i = &input[input_d.blk_off<!w_groups>(
                            g, i_mult * O, i_mult * I, h, w)];
                    auto o = &output[output_d.blk_off<!w_groups>(
                            g, o_mult * O, o_mult * I, h, w)];
                    const int oc_block = nstl::min(blksize, OC - O * blksize);
                    const int ic_block = nstl::min(blksize, IC - I * blksize);
                    ker(i, _wspace, oc_block, blksize, ic_block, blksize);
                    cvt_float_to_bfloat16(o, _wspace, wsp_size);
                });

        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<(tag_i == format_tag::nchw
                                          && tag_o == format_tag::nChw16c)
                && type_i == data_type::f32
                && type_o == data_type::bf16>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        using namespace data_type;

        if (input_d.has_runtime_dims_or_strides()) return false;

        return input_d.matches_tag(tag_i) && output_d.matches_tag(tag_o)
                && input_d.data_type() == f32 && output_d.data_type() == bf16
                && attr->has_default_values();
    }

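    // Here the staging buffer covers one full row: W x 16 channels of f32
    // per thread, converted to bf16 once the row has been gathered.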
    static size_t get_scratchpad_size(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d) {
        const size_t blksize = 16;
        const size_t W = input_d.dims()[3];
        return sizeof(float) * blksize * W * dnnl_get_max_threads();
    }

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        constexpr int blksize = 16;

        const auto &flat_d = input_d;
        const auto &dims = input_d.dims();
        const auto &pdims = output_d.padded_dims();

        const int C = dims[1];
        const int H = dims[2];
        const int W = dims[3];

        const int wsp_size = W * blksize;
        float *wspace = scratchpad.template get<float>(
                memory_tracking::names::key_reorder_space);

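        // For each w position, gather up to 16 channels from the strided
        // nchw source into the contiguous channel-blocked scratchpad row,
        // zero-filling the channel tail.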
        auto ker = [&](const data_t<type_i> *i, data_t<type_i> *o,
                           const int curr_c_block, const int c_block) {
            for (int w = 0; w < W; ++w) {
                int c = 0;
                for (c = 0; c < curr_c_block; ++c) {
                    const ptrdiff_t flat_off = 0
                            + c * flat_d.blocking_desc().strides[1]
                            + w * flat_d.blocking_desc().strides[3];
                    o[w * blksize + c] = i[flat_off];
                }
                for (/* continue */; c < c_block; ++c) {
                    o[w * blksize + c] = (data_t<type_i>)0;
                }
            }
        };

        constexpr int i_c_mult = blksize;
        constexpr int o_c_mult = 1;

        parallel_nd_ext(0, dims[0], pdims[1] / blksize, H,
                [&](int ithr, int, int n, int nb_c, int h) {
                    float *_wspace = wspace + wsp_size * ithr;
                    auto i = &input[input_d.blk_off(n, i_c_mult * nb_c, h)];
                    auto o = &output[output_d.blk_off(n, o_c_mult * nb_c, h)];
                    const int c_block = nstl::min(blksize, C - nb_c * blksize);
                    ker(i, _wspace, c_block, blksize);
                    cvt_float_to_bfloat16(o, _wspace, wsp_size);
                });

        return status::success;
    }
};

/* reorders with tail support */

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<false
                || (utils::one_of(
                            tag_i, format_tag::nCdhw4c, format_tag::nCdhw8c)
                        && tag_o == format_tag::nCdhw16c)
                || (utils::one_of(tag_i, format_tag::nChw4c, format_tag::nChw8c)
                        && tag_o == format_tag::nChw16c)
                || (utils::one_of(tag_i, format_tag::nCw4c, format_tag::nCw8c)
                        && tag_o == format_tag::nCw16c)>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d)
                && simple_attr_check(attr, false, true);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace format_tag;

        constexpr int is_1d = utils::one_of(tag_i, nCw4c, nCw8c);
        constexpr int is_3d = utils::one_of(tag_i, nCdhw4c, nCdhw8c);

        constexpr int blksize_i
                = tag_traits<tag_i>::inner_blks == ib::_4b ? 4 : 8;
        constexpr int blksize_16 = 16;

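        // One 16-wide block corresponds to 16/blksize_i source blocks: with
        // order_keep (4c/8c -> 16c) several input blocks feed one output
        // block; reversed, one input block scatters into several outputs.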
        constexpr int ic_mult = order_keep ? blksize_16 / blksize_i : 1;
        constexpr int oc_mult = order_keep ? 1 : blksize_16 / blksize_i;

        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        const auto &d_i = order_keep ? input_d : output_d;
        const auto stride_C_in_blk_i = d_i.blocking_desc().strides[1];

        const int C = dims[1];
        const int D = is_3d ? dims[2] : 1;
        const int H = is_1d ? 1 : dims[2 + is_3d];
        const int W = dims[3 + is_3d - is_1d];

        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                           const int block) {
            const int nb = utils::div_up(block, blksize_i);
            if (alpha == 1.0 && beta == 0.0) {
                for (int b = 0; b < nb; ++b) {
                    const ptrdiff_t i_off
                            = b * (order_keep ? stride_C_in_blk_i : blksize_i);
                    const ptrdiff_t o_off
                            = b * (order_keep ? blksize_i : stride_C_in_blk_i);
                    const int block_i
                            = nstl::min(blksize_i, block - b * blksize_i);
                    for (int c = 0; c < block_i; ++c) {
                        o[o_off + c] = _qz_a1b0<type_i, type_o>()(i[i_off + c]);
                    }
                    if (order_keep && b + 1 == nb) {
                        // zero padding
                        const auto pad_size
                                = blksize_16 - ((nb - 1) * blksize_i);
                        const auto pad_start = block_i + o_off;
                        const auto pad_end = pad_size + o_off;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; i++) {
                            o[i] = 0;
                        }
                    }
                }
            } else {
                for (int b = 0; b < nb; ++b) {
                    const ptrdiff_t i_off
                            = b * (order_keep ? stride_C_in_blk_i : blksize_i);
                    const ptrdiff_t o_off
                            = b * (order_keep ? blksize_i : stride_C_in_blk_i);
                    const int block_i
                            = nstl::min(blksize_i, block - b * blksize_i);
                    for (int c = 0; c < block_i; ++c) {
                        o[o_off + c] = _qz<type_i, type_o>()(
                                i[i_off + c], o[o_off + c], alpha, beta);
                    }
                    if (order_keep && b + 1 == nb) {
                        // zero padding
                        const auto pad_size
                                = blksize_16 - ((nb - 1) * blksize_i);
                        const auto pad_start = block_i + o_off;
                        const auto pad_end = pad_size + o_off;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; i++) {
                            o[i] = 0;
                        }
                    }
                }
            }
        };

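// Selects the right blk_off() arity for the 1d/2d/3d spatial cases handled
// by this specialization.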
#define data_blk_off(md, n, c, d, h, w) \
    (is_1d ? (md).blk_off(n, c, w) \
           : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))

        parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
                [&](int n, int nb_c, int d, int h, int w) {
                    auto i = &input[data_blk_off(
                            input_d, n, ic_mult * nb_c, d, h, w)];
                    auto o = &output[data_blk_off(
                            output_d, n, oc_mult * nb_c, d, h, w)];
                    const int block
                            = nstl::min(blksize_16, C - nb_c * blksize_16);
                    ker(i, o, block);
                });

#undef data_blk_off

        return status::success;
    }
};

#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
    static bool is_applicable(const memory_desc_wrapper &input_d, \
            const memory_desc_wrapper &output_d, \
            const primitive_attr_t *attr) { \
        return !input_d.has_runtime_dims_or_strides() \
                && simple_attr_check(attr, false, true) \
                && (order_keep ? output_d.matches_tag(tag_o) \
                                        && input_d.is_plain() \
                               : input_d.matches_tag(tag_o) \
                                        && output_d.is_plain()); \
    }

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                && (tag_traits<tag_o>::block_dims == bd::_A
                        || tag_traits<tag_o>::block_dims == bd::_B)
                && tag_traits<tag_o>::ndims >= 3
                && tag_traits<tag_o>::ndims <= 6>::type> {
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &block_d = order_keep ? output_d : input_d;
        const auto &dims = input_d.dims();
        const auto &pdims = block_d.padded_dims();

        constexpr int ndims = tag_traits<tag_o>::ndims;
        constexpr int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;

        const dim_t H0 = dims[0];
        const dim_t H1 = dims[1];
        const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1;
        const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
        const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
        const dim_t L = dims[ndims - 1];
        const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
        const dim_t l_flat_stride = flat_d.blocking_desc().strides[ndims - 1];
        const dim_t blk_flat_stride = flat_d.blocking_desc().strides[blk_idx];
        using namespace data_type;
        using namespace utils;

        constexpr int blksize = false
                ? 0
                : one_of(tag_traits<tag_o>::inner_blks, ib::_4a, ib::_4b)
                        ? 4
                        : one_of(tag_traits<tag_o>::inner_blks, ib::_8a,
                                  ib::_8b)
                                ? 8
                                : 16;

        constexpr bool f32bf16
                = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);

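        // f32<->bf16 pairs bypass the quantization helpers: a plain
        // assignment (implicit bfloat16 conversion) is sufficient and
        // cheaper than going through _qz.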
        auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
            if (f32bf16)
                out = inp;
            else
                out = _qz_a1b0<type_i, type_o>()(inp);
        };

        auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
                               float beta) {
            if (f32bf16)
                out = alpha * inp + (beta ? beta * out : 0);
            else
                out = _qz<type_i, type_o>()(inp, out, alpha, beta);
        };

        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
            if (alpha == 1.0 && beta == 0.0) {
                for (int l = 0; l < L; ++l) {
                    for (int blk = 0; blk < block; ++blk) {
                        const dim_t flat_off
                                = blk * blk_flat_stride + l * l_flat_stride;
                        const dim_t blk_offset = l * l_blk_stride + blk;
                        if (order_keep) {
                            wrap_qz_a1b0(o[blk_offset], i[flat_off]);
                        } else {
                            wrap_qz_a1b0(o[flat_off], i[blk_offset]);
                        }
                    }
                    if (order_keep) {
                        // zero padding
                        const auto pad_start = block + l * l_blk_stride;
                        const auto pad_end = blksize + l * l_blk_stride;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; ++i) {
                            o[i] = 0;
                        }
                    }
                }
            } else {
                for (int l = 0; l < L; ++l) {
                    for (int blk = 0; blk < block; ++blk) {
                        const dim_t flat_off
                                = blk * blk_flat_stride + l * l_flat_stride;
                        const dim_t blk_offset = l * l_blk_stride + blk;
                        if (order_keep)
                            wrap_qz(o[blk_offset], i[flat_off], alpha, beta);
                        else
                            wrap_qz(o[flat_off], i[blk_offset], alpha, beta);
                    }
                    if (order_keep) {
                        // zero padding
                        const auto pad_start = block + l * l_blk_stride;
                        const auto pad_end = blksize + l * l_blk_stride;
                        PRAGMA_OMP_SIMD()
                        for (int i = pad_start; i < pad_end; ++i) {
                            o[i] = 0;
                        }
                    }
                }
            }
        };

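// Collapses the supported ranks 3..6 onto one blk_off() call: h0/h1 are the
// two leading dims (one of which is blocked), m0..m2 the trailing ones.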
#define off(md, h0, h1, m0, m1, m2) \
    (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
                : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
                             : ndims >= 4 \
                                    ? (md).blk_off(h0, h1, m2) \
                                    : /* ndims >= 3 ? */ (md).blk_off(h0, h1))

        constexpr int i_mult = order_keep ? blksize : 1;
        constexpr int o_mult = order_keep ? 1 : blksize;

        if (blk_idx == 0) {
            const dim_t BH0 = pdims[0] / blksize;
            parallel_nd(BH0, H1, M0, M1, M2,
                    [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
                        auto i = &input[off(
                                input_d, bh0 * i_mult, h1, m0, m1, m2)];
                        auto o = &output[off(
                                output_d, bh0 * o_mult, h1, m0, m1, m2)];
                        const int block
                                = nstl::min<int>(blksize, H0 - bh0 * blksize);
                        ker(i, o, block);
                    });
        } else if (blk_idx == 1) {
            const dim_t BH1 = pdims[1] / blksize;
            parallel_nd(H0, BH1, M0, M1, M2,
                    [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
                        auto i = &input[off(
                                input_d, h0, bh1 * i_mult, m0, m1, m2)];
                        auto o = &output[off(
                                output_d, h0, bh1 * o_mult, m0, m1, m2)];
                        const int block
                                = nstl::min<int>(blksize, H1 - bh1 * blksize);
                        ker(i, o, block);
                    });
        } else {
            assert(!"unimplemented");
        }

#undef off

        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                && (tag_traits<tag_o>::block_dims == bd::_AB
                        || tag_traits<tag_o>::block_dims == bd::_BC)
                && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
                        tag_traits<tag_o>::ndims >= 3
                                && tag_traits<tag_o>::ndims <= 5)
                && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
                        tag_traits<tag_o>::ndims >= 4
                                && tag_traits<tag_o>::ndims <= 6)>::type> {
    PLAIN_TO_BLOCKED_IS_APPLICABLE();

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        const auto &flat_d = order_keep ? input_d : output_d;
        const auto &dims = input_d.dims();
        const auto &pdims
                = order_keep ? output_d.padded_dims() : input_d.padded_dims();

        constexpr int ndims = tag_traits<tag_o>::ndims;

        static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
        const dim_t G = with_g ? dims[0] : 1;

        const dim_t H0 = dims[0 + with_g];
        const dim_t H1 = dims[1 + with_g];

        const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
        const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
        const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;

        const dim_t h0_flat_stride = flat_d.blocking_desc().strides[with_g + 0];
        const dim_t h1_flat_stride = flat_d.blocking_desc().strides[with_g + 1];
        using namespace data_type;
        using namespace utils;

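        // The outer (blksize_0) and inner (blksize_1) block sizes are
        // decoded from the tag's inner_blks trait, e.g. _16a16b -> 16x16,
        // _8a8b -> 8x8, _4b4c -> 4x4; -1 marks an unsupported tag.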
        constexpr int blksize_0 = false
                ? 0
                : one_of(tag_traits<tag_o>::inner_blks, ib::_4b4a, ib::_4b4c,
                          ib::_4c4b)
                        ? 4
                        : one_of(tag_traits<tag_o>::inner_blks, ib::_8a8b,
                                  ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
                                ? 8
                                : one_of(tag_traits<tag_o>::inner_blks,
                                          ib::_16a16b, ib::_16b16a, ib::_16b16c,
                                          ib::_16c16b, ib::_8a16b2a,
                                          ib::_4b16a4b, ib::_8b16a2b,
                                          ib::_8b16c2b, ib::_4c16b4c,
                                          ib::_8c16b2c)
                                        ? 16
                                        : -1;

        constexpr int blksize_1
                = one_of(tag_traits<tag_o>::inner_blks, ib::_8a8b, ib::_8b8a,
                          ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
                ? 8
                : one_of(tag_traits<tag_o>::inner_blks, ib::_16a16b,
                          ib::_16b16a, ib::_16b16c, ib::_16c16b, ib::_8a16b2a,
                          ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
                          ib::_4c16b4c, ib::_8c16b2c)
                        ? 16
                        : one_of(tag_traits<tag_o>::inner_blks, ib::_4b4a,
                                  ib::_4b4c, ib::_4c4b)
                                ? 4
                                : -1;

        const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
        const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;

        constexpr bool f32bf16
                = one_of(type_i, f32, bf16) && one_of(type_o, f32, bf16);

        auto wrap_qz_a1b0 = [=](data_t<type_o> &out, data_t<type_i> inp) {
            if (f32bf16)
                out = inp;
            else
                out = _qz_a1b0<type_i, type_o>()(inp);
        };

        auto wrap_qz = [=](data_t<type_o> &out, data_t<type_i> inp, float alpha,
                               float beta) {
            if (f32bf16)
                out = alpha * inp + (beta ? beta * out : 0);
            else
                out = _qz<type_i, type_o>()(inp, out, alpha, beta);
        };

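        // Copies one block_h0 x block_h1 tile between plain and AB/BC-blocked
        // layouts; with order_keep it also zero-fills the h1 tail of each row
        // and any whole h0 tail rows up to the full blksize_0 x blksize_1
        // tile.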
        auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
                           const int block_h0, const int block_h1) {
#define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
            if (alpha == 1.0 && beta == 0.0) {
                for (int h0 = 0; h0 < block_h0; ++h0) {
                    for (int h1 = 0; h1 < block_h1; ++h1) {
                        const dim_t flat_off
                                = h0 * h0_flat_stride + h1 * h1_flat_stride;
                        if (order_keep)
                            wrap_qz_a1b0(o[blk_off(h0, h1)], i[flat_off]);
                        else
                            wrap_qz_a1b0(o[flat_off], i[blk_off(h0, h1)]);
                    }
                    if (order_keep && block_h1 < blksize_1) {
                        // zero padding
                        PRAGMA_OMP_SIMD()
                        for (int h1 = block_h1; h1 < blksize_1; h1++) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
                if (order_keep && block_h0 < blksize_0) {
                    // zero padding
                    for (int h0 = block_h0; h0 < blksize_0; h0++) {
                        PRAGMA_OMP_SIMD()
                        for (int h1 = 0; h1 < blksize_1; ++h1) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
            } else {
                for (int h0 = 0; h0 < block_h0; ++h0) {
                    for (int h1 = 0; h1 < block_h1; ++h1) {
                        const dim_t flat_off
                                = h0 * h0_flat_stride + h1 * h1_flat_stride;
                        if (order_keep)
                            wrap_qz(o[blk_off(h0, h1)], i[flat_off], alpha,
                                    beta);
                        else
                            wrap_qz(o[flat_off], i[blk_off(h0, h1)], alpha,
                                    beta);
                    }
                    if (order_keep && block_h1 < blksize_1) {
                        // zero padding
                        PRAGMA_OMP_SIMD()
                        for (int h1 = block_h1; h1 < blksize_1; h1++) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
                if (order_keep && block_h0 < blksize_0) {
                    // zero padding
                    for (int h0 = block_h0; h0 < blksize_0; h0++) {
                        PRAGMA_OMP_SIMD()
                        for (int h1 = 0; h1 < blksize_1; ++h1) {
                            o[blk_off(h0, h1)] = 0;
                        }
                    }
                }
            }

#undef blk_off
        };

        constexpr int i_mult_0 = order_keep ? blksize_0 : 1;
        constexpr int o_mult_0 = order_keep ? 1 : blksize_0;

        constexpr int i_mult_1 = order_keep ? blksize_1 : 1;
        constexpr int o_mult_1 = order_keep ? 1 : blksize_1;

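// off() folds the optional group dim into blk_off<!with_g>() and dispatches
// on ndims, mirroring the data_blk_off() helper above.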
#define off(md, g, h0, h1, m0, m1, m2) \
    (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
                         : ndims >= 4 + with_g \
                            ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
                            : /* ndims >= 3 + with_g ? */ (md) \
                                      .blk_off<!with_g>(g, h0, h1, m2))

        parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
                [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1,
                        dim_t m2) {
                    auto i = &input[off(input_d, g, i_mult_0 * nb_h0,
                            i_mult_1 * nb_h1, m0, m1, m2)];
                    auto o = &output[off(output_d, g, o_mult_0 * nb_h0,
                            o_mult_1 * nb_h1, m0, m1, m2)];
                    const int block_h0
                            = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
                    const int block_h1
                            = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
                    ker(i, o, block_h0, block_h1);
                });

#undef off

        return status::success;
    }
};

/* generic and direct-copy reorders */

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::direct_copy>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* FIXME: is the formula correct? */
        return !input_d.has_runtime_dims_or_strides()
                && input_d.similar_to(output_d, true, false, 0)
                && input_d.is_dense() && output_d.is_dense()
                && simple_attr_check(attr, false, true);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();

        assert(input_d.is_dense());

        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        const size_t nelems = input_d.nelems();

        constexpr int block_size = 16;
        const auto num_blocks = nelems / block_size;
        const auto rem_elems = nelems % block_size;

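        // The dense buffer is split into 16-element chunks balanced across
        // threads; the last thread additionally handles the remainder that
        // does not fill a whole chunk.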
        parallel(0, [&](const int ithr, const int nthr) {
            size_t start {0}, end {0};
            balance211(num_blocks, nthr, ithr, start, end);
            start = start * block_size;
            end = end * block_size;

            if (alpha == 1.0 && beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
                            input[e]);
                }
            } else if (alpha == 1.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
                            input[e], output[e], beta);
                }
            } else if (beta == 0.0) {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                            input[e], alpha);
                }
            } else {
                PRAGMA_OMP_SIMD()
                for (size_t e = start; e < end; ++e) {
                    output[e] = qz<data_t<type_i>, data_t<type_o>>()(
                            input[e], output[e], alpha, beta);
                }
            }

            if (rem_elems != 0 && ithr == nthr - 1) {
                if (alpha == 1.0 && beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()(
                                input[e]);
                    }
                } else if (alpha == 1.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()(
                                input[e], output[e], beta);
                    }
                } else if (beta == 0.0) {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()(
                                input[e], alpha);
                    }
                } else {
                    PRAGMA_OMP_SIMD()
                    for (size_t e = nelems - rem_elems; e < nelems; ++e) {
                        output[e] = qz<data_t<type_i>, data_t<type_o>>()(
                                input[e], output[e], alpha, beta);
                    }
                }
            }
        });
        return status::success;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::direct_copy_except_dim_0>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
            return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
        };
        /* FIXME: is the formula correct? */
        return !input_d.has_runtime_dims_or_strides()
                && input_d.similar_to(output_d, true, false, 1)
                && is_dense_no_0(input_d) && is_dense_no_0(output_d)
                && simple_attr_check(attr, false, true);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(const cpu_reorder_pd_t *pd, const exec_ctx_t &ctx) {
        DECLARE_COMMON_PARAMS();
        using namespace utils;

        input += input_d.blk_off(0);
        output += output_d.blk_off(0);

        const int N = input_d.dims()[0];
        const dim_t is = input_d.blocking_desc().strides[0];
        const dim_t os = output_d.blocking_desc().strides[0];
        const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
        const dim_t work_amount = N * nelems_no_d0;

        if (alpha == 1.0 && beta == 0.0) {
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n {0}, dim1_s {0};
                dim_t start {0}, end {0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
                            ? nelems_no_d0
                            : dim1_s + work_rem;
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e]
                                = _qz_a1b0<type_i, type_o>()(input[is * n + e]);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        } else {
            parallel(0, [&](const int ithr, const int nthr) {
                dim_t n {0}, dim1_s {0};
                dim_t start {0}, end {0};
                balance211(work_amount, nthr, ithr, start, end);
                nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
                while (start < end) {
                    dim_t work_rem = end - start;
                    dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
                            ? nelems_no_d0
                            : dim1_s + work_rem;
                    PRAGMA_OMP_SIMD()
                    for (dim_t e = dim1_s; e < dim1_e; ++e) {
                        output[os * n + e]
                                = _qz<type_i, type_o>()(input[is * n + e],
                                        output[os * n + e], alpha, beta);
                    }
                    nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
                }
            });
        }

        return status::success;
    }

private:
    static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
        const int ndims = data_d.ndims();
        if (ndims <= 1) return 1;
        return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
    }

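    // Physical size (in elements, including padding and blocking) spanned by
    // one index of dim 0; equality with nelems_no_dim_0() means the tensor
    // is dense below dim 0.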
    static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
        dims_t blocks;
        data_d.compute_blocks(blocks);

        const auto &blk = data_d.blocking_desc();

        dim_t blk_size = 1;
        for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
            blk_size *= blk.inner_blks[iblk];

        dim_t max_size = blk_size;
        for (int d = 1; d < data_d.ndims(); ++d) {
            max_size = nstl::max(max_size,
                    data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
        }

        return max_size;
    }
};

template <SIMPLE_REORDER_TEMPL_DECL>
struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
        typename utils::enable_if<tag_i == format_tag::any
                        && tag_o == format_tag::any
                        && order_keep == fmt_order::any,
                spec::reference>::type> {
    static bool is_applicable(const memory_desc_wrapper &input_d,
            const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
        /* supported smask: 0x0...011..10...0,
         * i.e. the set bits must be contiguous */
        int smask = attr ? attr->output_scales_.mask_ : 0;
        for (; smask > 0 && !(smask & 0x1); smask >>= 1)
            ;
        for (; smask > 0 && smask & 0x1; smask >>= 1)
            ;
        return input_d.is_blocking_desc() && output_d.is_blocking_desc()
                && !output_d.is_additional_buffer()
                && !input_d.is_additional_buffer() && smask == 0
                && attr->has_default_values(
                        dnnl_primitive_attr::skip_mask_t::oscale_runtime
                        | dnnl_primitive_attr::skip_mask_t::zero_points_runtime
                        | dnnl_primitive_attr::skip_mask_t::post_ops)
                && simple_po_check(attr);
    }

    GET_SCRATCHPAD_SIZE_ZERO();

    static status_t execute(
            const cpu_reorder_pd_t *pd_object, const exec_ctx_t &ctx) {
        // The DEFINE_SCALES_BUFFER and DEFINE_ZERO_POINT_VALUE macros use
        // pd() to query properties, hence the primitive descriptor is
        // wrapped into a function.
        auto pd = [pd_object]() { return pd_object; };

        auto input = CTX_IN_MEM(const data_t<type_i> *, DNNL_ARG_FROM);
        auto output = CTX_OUT_MEM(data_t<type_o> *, DNNL_ARG_TO);

        const float beta = pd()->beta();
        DEFINE_SCALES_BUFFER(scales);
        DEFINE_ZERO_POINT_VALUE(i0, DNNL_ARG_FROM);
        DEFINE_ZERO_POINT_VALUE(o0, DNNL_ARG_TO);

        const auto input_d = ctx.memory_mdw(DNNL_ARG_FROM, pd()->src_md());
        const auto output_d = ctx.memory_mdw(DNNL_ARG_TO, pd()->dst_md());

        const size_t nelems = input_d.nelems();

        // This kernel is also used for tensors with multiple inner
        // blocks, for which generic zero padding must be used.
        // TODO: apply zero padding inside parallel_nd()
        ctx.zero_pad_output(DNNL_ARG_TO);

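        // The (contiguous) scale mask is decomposed into three flattened
        // ranges: D_start (dims before the mask), D_mask (masked dims, one
        // scale per point), and D_rest (dims after), so scales[dm] can be
        // looked up directly in the loop below.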
        int ndims_start = 0, ndims_mask = 0;
        int smask = pd()->attr()->output_scales_.mask_;
        for (; smask > 0 && !(smask & 0x1); smask >>= 1)
            ++ndims_start;
        for (; smask > 0 && smask & 0x1; smask >>= 1)
            ++ndims_mask;
        assert(smask == 0);

        const ptrdiff_t D_start
                = utils::array_product(input_d.dims(), ndims_start);
        const ptrdiff_t D_mask = utils::array_product(
                input_d.dims() + ndims_start, ndims_mask);
        const ptrdiff_t D_rest = nelems / D_start / D_mask;

        parallel_nd(D_start, D_mask, D_rest,
                [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
                    const float scale = scales[dm];

                    const size_t e = (ds * D_mask + dm) * D_rest + dr;
                    const auto &i = input[input_d.off_l(e)];
                    auto &o = output[output_d.off_l(e)];

                    float f = scale * ((float)i - i0) + o0;
                    o = _qz<data_type::f32, type_o>()(f, o, 1.f, beta);
                });

        return status::success;
    }
};

/* high level class declaration */

template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
struct simple_reorder_t : public primitive_t {
    struct pd_t : public cpu_reorder_pd_t {
        using cpu_reorder_pd_t::cpu_reorder_pd_t;

        DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);

    private:
        static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
                const primitive_attr_t *attr, engine_t *src_engine,
                const memory_desc_t *src_md, engine_t *dst_engine,
                const memory_desc_t *dst_md) {
            bool args_ok = true && src_md->data_type == type_i
                    && dst_md->data_type == type_o
                    && attr->has_default_values(
                            dnnl_primitive_attr::skip_mask_t::oscale_runtime
                            | dnnl_primitive_attr::skip_mask_t::zero_points
                            | dnnl_primitive_attr::skip_mask_t::
                                    zero_points_runtime
                            | dnnl_primitive_attr::skip_mask_t::post_ops)
                    && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
                            spec>::is_applicable(src_md, dst_md, attr);
            if (!args_ok) return status::invalid_arguments;

            auto _pd = new pd_t(attr, src_engine->kind(), src_md,
                    dst_engine->kind(), dst_md);
            if (_pd == nullptr) return status::out_of_memory;
            if (_pd->init(engine, src_engine, dst_engine) != status::success) {
                delete _pd;
                return status::unimplemented;
            }

            const size_t scratchpad_sz_
                    = simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
                            spec>::get_scratchpad_size(src_md, dst_md);
            auto scratchpad = _pd->scratchpad_registry().registrar();
            scratchpad.book(memory_tracking::names::key_reorder_space,
                    scratchpad_sz_, 1, 16);
            _pd->init_scratchpad_md();
            return safe_ptr_assign(*reorder_pd, _pd);
        }
        friend dnnl::impl::impl_list_item_t;
    };

    simple_reorder_t(const pd_t *apd) : primitive_t(apd) {}

    status_t execute(const exec_ctx_t &ctx) const override {
        return simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
                pd(), ctx);
    }

private:
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};
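// Usage note: a concrete reorder is obtained by instantiating this template
// with types/tags, e.g.
//   simple_reorder_t<data_type::f32, format_tag::nchw, data_type::bf16,
//           format_tag::nChw16c, fmt_order::keep>
// and registering the instantiation in a cpu reorder implementation list
// (the exact registration mechanism lives outside this header; illustrative
// sketch only).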

#undef SIMPLE_REORDER_TEMPL_DECL
#undef SIMPLE_REORDER_TEMPL_CALL

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s