1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_PRIMITIVE_CONF_HPP
18 #define CPU_X64_JIT_PRIMITIVE_CONF_HPP
19 
20 #include <queue>
21 #include <stdint.h>
22 
23 #include "common/primitive_attr.hpp"
24 #include "cpu/x64/brgemm/brgemm_types.hpp"
25 #include "cpu/x64/cpu_isa_traits.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31 
32 /* convolution */
// Implementation "version" tag of a convolution kernel. Enumerators are named
// after the instruction-set feature they target (FMA, AVX-512 core, 4FMA,
// VNNI); ver_unused is the placeholder value.
// Values are spelled out explicitly and equal the implicit declaration-order
// numbering, so existing code that stores or compares them is unaffected.
enum conv_version_t {
    ver_unused = 0,
    ver_fma = 1,
    ver_avx512_core = 2,
    ver_4fma = 3,
    ver_vnni = 4,
};
// Loop-nest order used by the direct convolution drivers. Each suffix lists
// loop letters from outermost to innermost.
// NOTE(review): presumably c = channel block, g = group, n = minibatch and
// w/h = spatial dimensions -- confirm against the drivers.
// Explicit values match the implicit declaration-order numbering.
enum conv_loop_order_t {
    loop_cgn = 0,
    loop_gnc = 1,
    loop_ngc = 2,
    loop_gncw = 3,
    loop_cwgn = 4,
    loop_ngcw = 5,
    loop_nhwcg = 6,
    loop_nwcg = 7,
};
// Loop-nest order for the 1x1 convolution kernels; each enumerator is a
// permutation of three loop letters.
// NOTE(review): the letters presumably map to the reduce (r), bcast (b) and
// load (l) dimensions of jit_1x1_conv_conf_t -- confirm.
// Explicit values match the implicit declaration-order numbering.
enum conv_1x1_loop_order_t {
    loop_rbl = 0,
    loop_rlb = 1,
    loop_lbr = 2,
    loop_lrb = 3,
    loop_blr = 4,
    loop_brl = 5,
};
58 
// Kernel flavour. NOTE(review): presumably "embedded" vs "explicit" broadcast
// of source values -- confirm. Explicit values match the implicit numbering.
enum conv_kernel_kind_t {
    embd_bcast = 0,
    expl_bcast = 1,
};
// Driver "harness" selecting how convolution work is decomposed; the names
// indicate which dimensions a reduction is performed over (2D, 3D, minibatch)
// or the spatial / nxc-layout specific paths.
// Explicit values match the implicit declaration-order numbering.
enum conv_harness_t {
    harness_2d_reduction = 0,
    harness_3d_reduction = 1,
    harness_mb_reduction = 2,
    harness_compute_full_spatial = 3,
    harness_nxc = 4,
};
67 
// Bit flags exchanged between drivers and kernels.
//
// NOTE: this enum holds two independent flag sets that deliberately reuse
// the same bit positions.
//   * FLAG_{MB,OC,IC,SP,REDUCE}_{FIRST,LAST} occupy bits 0..9. Presumably
//     they mark the first/last iteration along the named dimension.
//   * FLAG_ZERO_FILTER, FLAG_ZERO_BIAS and FLAG_COMPUTE_BIAS occupy bits
//     0..2. They therefore compare equal to FLAG_MB_FIRST, FLAG_MB_LAST and
//     FLAG_OC_FIRST respectively, so the two sets must not be mixed in one
//     flags word.
// NOTE(review): the second set is presumably carried in
// jit_dw_conv_call_s::exec_flags, an unsigned char where only bits 0..7
// fit -- confirm against its users.
enum {
    FLAG_MB_FIRST = 1 << 0,
    FLAG_MB_LAST = 1 << 1,
    FLAG_OC_FIRST = 1 << 2,
    FLAG_OC_LAST = 1 << 3,
    FLAG_IC_FIRST = 1 << 4,
    FLAG_IC_LAST = 1 << 5,
    FLAG_SP_FIRST = 1 << 6,
    FLAG_SP_LAST = 1 << 7,
    FLAG_REDUCE_FIRST = 1 << 8,
    FLAG_REDUCE_LAST = 1 << 9,
    FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips
                                   loading weights-data from memory; this
                                   needs to happen on the first Group/16
                                   iteration. */
    FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skips
                               loading bias data from memory */
    FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution
                                    pass */
};
88 
// Coarse classification of a tensor's memory layout:
//   ncsp, nspc -- the two plain (non-blocked) layouts; nspc is channels-last
//                 (e.g. nwc / nhwc / ndhwc),
//   blocked    -- channel-blocked layouts,
//   undef      -- not determined / unsupported.
// Explicit values match the implicit declaration-order numbering.
enum class jit_memory_tag_kind_t {
    ncsp = 0,
    nspc = 1,
    blocked = 2,
    undef = 3,
};
90 
// Configuration record shared by the direct (non-1x1) x64 JIT convolution
// implementations.
//
// Fields are grouped by the kernel family that introduced them (see the
// section comments below), so most fields are only meaningful to a subset of
// the implementations. No field has a default member initializer.
// NOTE(review): instances are presumably zero-initialized by the code that
// fills them in -- confirm before relying on any default value.
//
// Spatial naming convention used throughout: d/h/w = depth/height/width,
// i*/o* = input/output extent, k* = kernel extent. Padding pairs are
// f_pad/back_pad (depth), t_pad/b_pad (height) and l_pad/r_pad (width).
struct jit_conv_conf_t {
    prop_kind_t prop_kind;
    conv_version_t ver;
    conv_loop_order_t loop_order;
    conv_harness_t harness;

    int simd_w;
    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int kd, kh, kw;
    int stride_d, stride_h, stride_w;
    // NOTE(review): presumably zero-based dilation (0 = dense filter), the
    // convention used by calculate_extended_filter_size -- confirm.
    int dilate_d, dilate_h, dilate_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;

    data_type_t sum_dt;

    bool with_binary_per_oc_bcast;
    bool with_binary_no_bcast;

    bool is_fused_conv;
    int dw_conv_buffer_oc;

    post_ops_t::entry_t::eltwise_t eltwise;
    post_ops_t post_ops;
    bool is_fast_postops; // maybe skip injector for sum and/or relu

    // Thread counts: total and per decomposed dimension.
    int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh;

    int idp, ihp, iwp, ohp, owp, icp;
    int nb_ic, ic_block;
    int nb_oc, oc_block;
    int nb_iw, iw_block;
    int nb_ow, ow_block;
    int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking
                           into account vector registers distribution */
    int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
                                      within threads */
    int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
    int nb_ic_L2;
    int h_blocking;
    int nb_oc_L2;
    int ic_tail, oc_tail, ch_tail;
    int ur_h, ur_w;
    int ur_w_tail, ur_w_blocks;
    int ur_ic, ur_kw;
    bool is_1stconv;
    int nonblk_group_off;
    /* fma avx512_core */
    conv_kernel_kind_t kernel_kind;
    /* 4fma */
    int tr_iw, tr_ih;
    int tr_kw, tr_kh;
    int tr_src_num_guard_elems;

    // Transpose buffer management
    size_t tr_src_buf_size, tr_src_buf_count;
    size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count;
    int nthr_mb_work;

    /* 1st conv: 4fma */
    int tr_ld;
    int kh_step;
    /* 4vnni */
    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;
    /* avx512_u8s8u8 */
    int ic_nb1, ic_nb2;
    int oc_nb1;
    int ur_ow_max, ur_ow, ur_ow_tail;
    int ur_ow_nsteps;
    data_type_t bia_dt;
    /* bf16 data-type for output */
    data_type_t dst_dt;
    data_type_t src_dt;
    /* bf16 weights update */
    data_type_t wei_dt;
    data_type_t ddst_dt;
    data_type_t dsrc_dt;
    data_type_t dwei_dt;
    bool expl_bcast;
    bool large_spatial, large_w_filter;
    int is_ic_scale, is_oc_scale;
    int max_regs_ur; // maximum accumulation registers
    // dw conv
    int nb_ch, ch_block, nb_ch_blocking;
    bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
    int aligned_threads;
    // large spatial
    int ih_blk_size, oh_blk_size;
    // s8s8 convolution
    bool signed_input;
    bool need_saturation;
    float wei_adj_scale;
    // zero-point compensation
    bool src_zero_point;
    int zp_pbuff_size;
    bool dst_zero_point;
    bool zp_src_is_common; // common, otherwise (TODO) per-channel
    bool req_zero_point_buffer; // used for calculating padding compensation
    bool zp_pbuff_outer_compute; // indicates if zp_pbuff is computed in
    // a separate parallel region
    int ow_pad, oh_pad, od_pad; // output elements with padding & filter overlap

    // Output elements requiring zero-point padding compensation
    int f_pad_output, back_pad_output;
    int t_pad_output, b_pad_output;
    int l_pad_output, r_pad_output;
    // The number of output blocks corresponding to {l_pad, no_pad, r_pad}
    int l_pad_blk, no_pad_w_blk, r_pad_blk;

    // Indicate whether there is overlap between the width and height padded
    // regions.
    bool od_mid, oh_mid, ow_mid;

    size_t h_blk_limits[5]; // pre-computed limits for output height block

    bool uses_permw_transposition;
    bool transpose_src;
    bool transpose_dst;
    int ic_block_step;

    cpu_isa_t isa;
    // bf16 bwdw conv
    int tr_ow;
    bool is_hw_transp; // spatial dim height-width transposed
    int spatial_blk_size; // Height/depth block size inside the driver
    bool global_transpose; // diff_dst & src tensors are transposed in one go
    bool use_nt_stores_ddst; // Use non temporal stores in diff_dst transform

    // Needed for Intel(R) Advanced Matrix Extensions (Intel(R) AMX) kernels
    bool is_nspc; // activations in nwc, nhwc, or ndhwc layout
    bool is_relo; // reduced lowering optimization
    int nreduce; // used with is_relo
    bool is_pbuffer_strided; // does pbuffer have strided sectors?
    int n_stride_sets; // number of stride sectors (or sets) in pbuffer
    int kw_step; // usually stride_w, unless !is_pbuffer_strided
    int kw_per_tile; // mostly for 1st convs
    // The suffix _int refers to the block sizes of the src and diff_dst tiles,
    // as opposed to the vector registers. This distinction is needed due to
    // support for blocked layout (ie nChw16c) with bf16 data type.
    int ic_block_int, ic_block_int_np, oc_block_int;
    int nb_ic_int, nb_oc_int;
    int nb_ih_blocking, nb_oh_blocking;

    int full_tile_width;
    int max_tiles;
    int tile_width;
    int tile_tail;
    int oh_per_tile;
    int iw_blocks, ow_blocks;

    int per_one_pstore;

    size_t inp_buffer_size;
    size_t wei_buffer_size;
    size_t wsp_buffer_size;

    int nb_os;
    int nb_os_blocking;
    int nb_os2_blocking;
    int os_tail;
    int os_blocked;
    int max_width;

    bool transform_to_vnni;
};
266 
// Returns the number of input elements spanned by a filter of `filter_size`
// taps once dilation is applied.
//
// `dilation` uses the zero-based convention: 0 means adjacent taps (the
// result equals `filter_size`), and each extra unit inserts one more skipped
// element between neighbouring taps.
inline int calculate_extended_filter_size(int filter_size, int dilation) {
    const int tap_distance = dilation + 1; // distance between adjacent taps
    return 1 + tap_distance * (filter_size - 1);
}
271 
calculate_end_padding(int start_padding,int dst_size,int src_size,int spatial_stride,int dilated_filter_size)272 inline int calculate_end_padding(int start_padding, int dst_size, int src_size,
273         int spatial_stride, int dilated_filter_size) {
274     return (dst_size - 1) * spatial_stride + dilated_filter_size
275             - (src_size + start_padding);
276 }
277 
init_tag(format_tag_t & tag,const memory_desc_wrapper & mdw,const format_tag_t & tag_value)278 inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw,
279         const format_tag_t &tag_value) {
280     if (mdw.format_kind() == format_kind::any) return status::unimplemented;
281 
282     tag = mdw.matches_one_of_tag(tag_value);
283     return tag == tag_value ? status::success : status::unimplemented;
284 }
285 
// Configuration for the x64 JIT 2x3 Winograd convolution implementation.
// Problem-shape fields follow the same naming as jit_conv_conf_t.
// NOTE(review): following the usual Winograd F(m, r) notation, m is presumably
// the output tile size, r the filter size and alpha the transformed tile size
// (m + r - 1) -- confirm against the code that fills this struct.
struct jit_conv_conf_2x3_wino_t {
    conv_version_t ver;

    int m;
    int r;
    int alpha;
    int tile_h, tile_w;

    int mb;
    int ngroups, ic, oc, oc_without_padding;
    int ih, iw, oh, ow;
    int l_pad, t_pad;
    int r_pad, b_pad;
    int kh, kw;
    int stride_h, stride_w;
    int dilate_h, dilate_w;

    int nb_ic, ic_block;
    int nb_oc, oc_block;

    int w_block_size, h_block_size;

    data_type_t bia_dt;
    data_type_t dst_dt;

    int is_oc_scale;
    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;

    format_tag_t src_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool small_mb;

    int xb, yb;
    int inp_stride;
    int out_stride;
    int wei_stride;
    int bia_stride;

    // GEMM dimensions and their blocking for the transformed problem.
    int M, N, K;
    int m_block, n_block, k_block;
    int n2_block, n_chunks;
    int k2_block, k_chunks;

    int mb_block, nb_mb;

    // NOTE(review): presumably sizes of the Winograd-domain src/weights/dst
    // buffers -- confirm units (elements vs bytes).
    size_t size_wino_src, size_wino_wei, size_wino_dst;

    int nthr;
};
338 
339 /*
340    Winograd sched policy:
341 
342    Computation Unit:
343    W: weights transform
344    S: src transform
345    D: dst transform
346    G: gemm
347 
348    Thread grouping by:
349    i: nb_ic
350    o: nb_oc
351    t: tile_block
352    e: element in tile
353 
354    Note: 'i' and 'o' are omitted if
355    i. not combined with t or
356    ii. with discrete transforms
357 
358    Current policies supported:
359 */
360 enum winograd_sched_t {
361     WSCHED_INVALID = 0,
362 
363     /* Forward & backward-data */
364     /* W_S_G_D implements discrete transforms */
365     WSCHED_DATA_W_S_G_D,
366     /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/
367     WSCHED_DATA_W_SGD,
368 
369     /* Backward-weights */
370     WSCHED_WEI_S_D_G_W,
371     WSCHED_WEI_SDGtWo,
372     WSCHED_WEI_S_D_Giot_W,
373     WSCHED_WEI_SDGt_W,
374 };
375 
// Winograd convolution configuration.
//
// Extends jit_conv_conf_t with tile and GEMM blocking parameters. The GEMM
// dimensions are dimM/dimN/dimK, each with register-, simd- and cache-level
// blocking. ic_simd_block and oc_simd_block default to 16; no other field
// added here has a default member initializer.
struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
    int itiles;
    int jtiles;
    int ntiles;
    int ic_simd_block = 16;
    int tile_4fma_padding;
    int tile_4fma;
    int oc_simd_block = 16;
    int oc_reg_block;
    int ic_reg_block;
    int tile_block;
    int tile_block_ur;
    int nb_tile_block_ur;

    bool double_buffering;
    bool with_relu_postsum;
    int zmm_start;
    int nb_reg;

    int dimK;
    int dimK_4fma;
    int dimK_reg_block;
    int dimK_block;
    int dimK_nb_block;

    int dimM;
    int dimM_reg_block;
    int dimM_simd_block;
    int dimM_block;
    int dimM_nb_block;

    int dimN;
    int dimN_reg_block;
    int dimN_bcast_ur;
    int dimN_block;
    int dimN_nb_block;

    // Scheduling policy; see winograd_sched_t.
    winograd_sched_t sched_policy;
};
415 
// Per-invocation argument block for the convolution JIT kernels.
//
// NOTE(review): generated code presumably reads these fields by member
// offset. Treat declaration order, types and sizes as part of the kernel
// interface and confirm before reordering or retyping anything.
// Fields with a `_prf` suffix presumably mirror their unsuffixed counterpart
// for a prefetch target -- confirm.
struct jit_conv_call_s {
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *src_prf;
    const void *dst_prf;
    const void *filt_prf;
    const void *bias_prf;
    const void *scales;
    const void *acc_s32;
    const void *compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *zero_point_pbuff;
    const int32_t *dst_zero_point;
    const void *tile_cfg;
    const void *tile_cfg_tail;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    const void *dst_orig; // pointer to dst memory (no offset)

    size_t oc_l_off_prf;
    const void *dst_orig_prf;

    size_t kd_offset;
    size_t kd_offset_prf;
    size_t kh_offset;
    size_t kh_offset_prf;
    size_t os_index_begin;
    size_t os_index_begin_prf;
    size_t os_index_end;
    size_t os_index_end_prf;
    size_t kd_padding;
    size_t kd_padding_prf;
    size_t kh_padding;
    size_t kh_padding_prf;
    size_t iwb;
    size_t iwb_prf;
    size_t owb;
    size_t owb_prf;
    size_t ohb;
    size_t kw_padding;
    size_t channel;
    size_t channel_prf;
    size_t ic_blocks;
    size_t oc_blocks;
    size_t ur_w;
    size_t ur_str_w;
    size_t ch_blocks;
    size_t ch_blocks_prf;
    size_t reduce_work;
    size_t reduce_work_prf;
    size_t load_work;
    size_t load_work_prf;
    size_t l_overflow;
    size_t r_overflow;
    size_t t_overflow;
    size_t b_overflow;
    size_t f_overflow;
    size_t back_overflow;
    size_t last_h;
    size_t tail;
    size_t current_iw;
    size_t is_osb;
    // NOTE(review): presumably a combination of the FLAG_* bits declared
    // at the top of this header -- confirm.
    int flags;
    int flags_prf;
    int oc_flag;
    size_t last_ic_block;
    size_t last_oc_block;
};
492 
// Per-invocation argument block for the deconvolution JIT kernels.
// NOTE(review): presumably read by generated code via member offsets; keep
// field order and types stable unless the kernels are updated too.
struct jit_deconv_call_s {
    const void *src; /* hack, non-const for backward_data */
    const void *dst; /* hack, non-const for forward */
    const void *filt; /* hack, non-const for backward_weights */
    const void *bias; /* hack, non-const for backward_bias */
    const void *scales;
    const void *compensation;
    const int32_t *zp_src_pad_str_compensation;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *dst_zero_point;

    /*
     * ptr to table of void * elements that are pointers to post_op binary
     * src1 tensors
     */
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig; /* pointer to dst memory (no offset) */
    /*
     * logical (# of elems) offset to the processed output channel
     * (for broadcasting [1,OC,1,1])
     */
    size_t oc_l_off;
    size_t t_overflow;
    size_t b_overflow;
    size_t f_overflow;
    size_t back_overflow;
    size_t kh_padding;
    size_t kd_padding;
    size_t oc_blocks;
};
524 
// Per-invocation argument block for the depthwise convolution JIT kernels.
// NOTE(review): presumably read by generated code via member offsets; keep
// field order and types stable unless the kernels are updated too.
struct jit_dw_conv_call_s {
    const void *input;
    const void *output;
    const void *filter;
    const void *bias;
    size_t kh_count;
    size_t oh_count;
    size_t oh_index;
    size_t filter_pad_off;
    // NOTE(review): presumably holds FLAG_ZERO_FILTER / FLAG_ZERO_BIAS /
    // FLAG_COMPUTE_BIAS; being an unsigned char, only bits 0..7 fit.
    unsigned char
            exec_flags; /* Flags passed by driver execution to inner kernel */
};
537 
// Per-invocation argument block for the Winograd transform kernels: tile
// bookkeeping (tile_block / tile_block_ur / tile coordinates tj, ti) followed
// by raw buffer pointers.
// NOTE(review): the roles of Mw, M, T and G are not evident from this header
// (presumably transformed data / scratch / transform-matrix buffers) -- confirm.
struct jit_wino_transform_call_s {
    size_t tile_block;
    size_t tile_block_ur;
    size_t nb_tile_block_ur;
    size_t tile_count;
    size_t tj;
    size_t ti;
    void *src;
    void *dst;
    void *Mw;
    void *M;
    void *T;
    void *G;
    void *bias;
};
553 
// Configuration for the x64 JIT 1x1 convolution implementations.
//
// Besides the usual problem shape, the work is expressed through three
// generic dimensions -- reduce, load and bcast -- each with a block size, a
// block count and blocking factors.
// NOTE(review): presumably reduce = input channels, load = output channels
// and bcast = spatial points for forward propagation -- confirm per
// prop_kind.
struct jit_1x1_conv_conf_t {
    prop_kind_t prop_kind;
    conv_version_t ver;

    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;
    int id, ih, iw, od, oh, ow;
    int f_pad, t_pad, l_pad;
    int kd, kh, kw;
    int stride_d, stride_h, stride_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;
    bool with_dw_conv;

    post_ops_t post_ops;

    int is, os;
    int ic_block, oc_block;

    int ur, ur_tail;

    int reduce_dim, reduce_block, nb_reduce, nb_reduce_blocking,
            nb_reduce_blocking_max;
    int load_dim, load_block, nb_load, nb_load_blocking, nb_load_blocking_max,
            nb_load_chunk;
    int bcast_dim, bcast_block, nb_bcast, nb_bcast_blocking,
            nb_bcast_blocking_max;

    int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
    int load_loop_load_step, load_loop_iter_step;
    int bcast_loop_output_step, bcast_loop_output_substep;
    int bcast_loop_bcast_step, bcast_loop_bcast_substep;
    int fma_step;
    int load_grp_count;
    conv_1x1_loop_order_t loop_order;
    bool use_vmovntps;
    /* avx512 core */
    bool expl_bcast;
    /* 4vnni */
    int typesize_in;
    int typesize_out;
    int typesize_bia;
    int typesize_acc;
    /* 4fma */
    bool transpose_src;
    int tr_is;
    int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
    int is_oc_scale;
    data_type_t bia_dt;
    data_type_t dst_dt;
    data_type_t sum_dt;
    bool signed_input;
    float wei_adj_scale;
    // zero-point compensation
    bool src_zero_point;
    bool dst_zero_point;
    bool zp_src_is_common; // common, otherwise (TODO) per-channel

    cpu_isa_t isa;
    bool uses_permw_transposition;
};
619 
// Per-invocation argument block for the 1x1 convolution JIT kernels.
// NOTE(review): presumably read by generated code via member offsets; keep
// field order and types stable unless the kernels are updated too.
struct jit_1x1_conv_call_s {
    const void *bcast_data;
    const void *load_data;
    const void *output_data;
    const void *bias_data; // used in forward and backward_weights only
    const void *acc_s32;
    const void *scales;
    const void *compensation;
    const void *store_buffer;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point;
    const int32_t *dst_zero_point;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // logical (# of elems) offset to the processed pixel
    // (for non-broadcasting policy)
    size_t dst_l_off;
    const void *dst_orig; // pointer to dst memory (no offset)

    size_t load_dim;
    size_t bcast_dim;
    size_t reduce_dim;

    size_t output_stride; // used in backward_weights only

    // NOTE(review): presumably FLAG_*_FIRST / FLAG_*_LAST bits -- confirm.
    size_t first_last_flag;
};
652 
// Configuration for the x64 JIT pooling kernels (forward and backward).
struct jit_pool_conf_t {
    int ndims;
    int mb, c, c_without_padding;
    int id, ih, iw, od, oh, ow;
    int stride_d, stride_h, stride_w;
    int kd, kh, kw;
    int f_pad, t_pad, l_pad;
    alg_kind_t alg;
    bool is_training;
    bool pad_w_is_null;
    bool is_backward;
    bool simple_alg;
    bool is_c_padded;
    data_type_t ind_dt;

    int c_block, c_tail, nb_c;
    int ur_bc, ur_bc_tail;
    int ur_c, ur_c_tail;
    int ur;
    size_t tail[4];
    bool safe_c_tail;
    data_type_t src_dt;
    data_type_t dst_dt;

    int dt_size;
    bool is_bf16;
    jit_memory_tag_kind_t tag_kind;
    // True for the two plain layouts (ncsp and nspc); false for blocked and
    // undef.
    bool is_plain() const {
        return (tag_kind == jit_memory_tag_kind_t::ncsp
                || tag_kind == jit_memory_tag_kind_t::nspc);
    }

    cpu_isa_t isa;
    post_ops_t post_ops;
    bool with_postops;
    bool with_eltwise;
    bool with_binary;
    int nthr;
};
692 
// Per-invocation argument block for the pooling JIT kernels.
// NOTE(review): presumably read by generated code via member offsets; keep
// field order and types stable unless the kernels are updated too.
struct jit_pool_call_s {
    const void *src;
    const void *dst;
    const void *indices;
    const void *src_prf;
    const void *dst_prf;
    const void *indices_prf;
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig;
    size_t c_elem_off;
    size_t zero_ih;
    size_t zero_id;
    const void *zero_ptr;
    size_t kd_padding;
    size_t kh_padding;
    size_t kh_padding_shift;
    size_t kd_padding_shift;
    size_t kw_padding;
    const void *init_value;
    float ker_area_h;
    size_t ur_bc; // contains number of channel blocks to process
    size_t b_c; // contains number of channel blocks already processed
};
716 
// Configuration for the x64 JIT resampling kernels.
// Every field has a default member initializer, so a default-constructed
// instance starts from zeros / undef / isa_any.
struct jit_resampling_conf_t {
    unsigned ndims = 0;

    unsigned c = 0;
    unsigned id = 0, ih = 0, iw = 0;
    unsigned od = 0, oh = 0, ow = 0;

    unsigned stride_d = 0;
    unsigned stride_h = 0;
    unsigned stride_w = 0;
    unsigned inner_stride = 0;

    // The linear algorithm approximates a point's value from the values at
    // the surrounding limits. In one dimension the approximation is based on
    // a line, in two dimensions on a rectangle, and in three dimensions on a
    // cuboid. Therefore the number of corners can be 2, 4 or 8.
    unsigned number_of_corners = 0;

    bool is_data_size_bigger_than_L3 = false;
    bool is_saturation_needed = false;
    data_type_t src_data_type = data_type::undef;
    data_type_t dst_data_type = data_type::undef;
    size_t src_dt_size = 0;
    size_t dst_dt_size = 0;
    size_t output_data_size = 0;
    size_t el_size_of_indices = 0;

    bool is_blocked_8_format = false;
    format_tag_t src_tag = format_tag::undef;
    jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
    alg_kind_t alg = alg_kind::undef;

    cpu_isa_t isa = isa_any;

    post_ops_t post_ops = post_ops_t();
    bool with_postops = false;
    bool with_eltwise = false;
    bool with_binary = false;
    bool with_sum = false;
    // NOTE(review): presumably one scale per sum post-op, consumed in FIFO
    // order during code generation -- confirm.
    std::queue<float> sum_scales;
};
760 
// Per-invocation argument block for the resampling JIT kernels; all fields
// default to zero / nullptr.
// The *_top/bottom/front/back offsets and weights presumably describe the
// neighbouring source points used for linear interpolation -- NOTE(review):
// confirm against the kernel.
struct jit_resampling_call_s {
    size_t batch_of_sp_points_to_process = 0;

    const void *src = nullptr;
    const void *dst = nullptr;
    const void *indices = nullptr;
    const void *weights = nullptr;
    const void *post_ops_binary_rhs_arg_vec = nullptr;
    const void *dst_orig = nullptr;

    size_t c_offset = 0;

    size_t src_offset_top = 0;
    size_t src_offset_bottom = 0;
    size_t src_offset_front = 0;
    size_t src_offset_back = 0;

    float weight_top = 0.0f;
    float weight_bottom = 0.0f;
    float weight_front = 0.0f;
    float weight_back = 0.0f;
};
783 
// Configuration for the brdgmm-based convolution implementation (2D only:
// h/w spatial fields). NOTE(review): given the ch_block / nb_ch fields this
// presumably serves depthwise convolutions -- confirm.
struct jit_brdgmm_conv_conf_t {

    int nthr;
    int mb, ngroups, ic, oc;
    int ih, iw, oh, ow;
    int l_pad, r_pad, t_pad, b_pad;
    int kh, kw;
    int stride_h, stride_w;
    int nb_ch, ch_block, chb_tail;
    int nb_ch_blocking;
    int ow_block, ow_tail, nb_ow;
    // idx of jit kernel when multiple jit kernels are used in a primitive.
    int chb_tail_idx, ow_tail_idx, nb_ch_blocking_idx;
    int adjusted_batch_size;

    bool with_bias;
    bool with_post_ops;
    bool is_oc_scale;

    data_type_t src_dt;
    data_type_t wei_dt;
    data_type_t bia_dt;
    data_type_t dst_dt;

    brgemm_batch_kind_t batch_kind;

    // Element sizes (presumably in bytes) of the corresponding data types.
    size_t src_dsz;
    size_t wei_dsz;
    size_t bia_dsz;
    size_t dst_dsz;

    cpu_isa_t isa;
};
817 
// Outer loop order for the brgemm-based convolution drivers (letters listed
// from outermost to innermost). Explicit values match the implicit numbering.
enum conv_brgemm_loop_order_t {
    loop_ndhwgc = 0,
    loop_ngcdhw = 1,
};
822 
// Execution strategy selected for a brgemm-based convolution; exec_undefined
// (0) means no strategy has been chosen yet.
// NOTE(review): exec_trans / exec_vpad presumably denote the transformed-input
// and virtual-padding paths -- confirm.
// Explicit values match the implicit declaration-order numbering.
enum conv_brgemm_exec_type_t {
    exec_undefined = 0,
    exec_base = 1,
    exec_trans = 2,
    exec_vpad = 3,
};
829 
// Configuration for the brgemm-based x64 convolution implementations.
// Problem-shape fields follow the naming of jit_conv_conf_t. The M/N/K and
// LD* fields describe the underlying batch-reduce GEMM calls.
struct jit_brgemm_conv_conf_t {
    cpu_isa_t isa;
    prop_kind_t prop_kind;
    conv_version_t ver;
    conv_brgemm_loop_order_t loop_order;
    conv_harness_t harness;
    int simd_w, amx_w, amx_h;
    int ndims;
    int mb;
    int ngroups, ic, oc, oc_without_padding, ic_without_padding;

    int od_block, oh_block, nb_od,
            nb_oh; // blocking  - included in parallelization
    dim_t inp_buffer_size, inp_buffer_mask_size;
    conv_brgemm_exec_type_t exec_type;

    int id, ih, iw, od, oh, ow, os, idp, ihp, iwp, icp;
    int f_pad, l_pad, t_pad;
    int back_pad, r_pad, b_pad;
    int kd, kh, kw;
    // Kernel extents including dilation (presumably computed with
    // calculate_extended_filter_size).
    int ext_kd, ext_kh, ext_kw;
    int kd_block, kh_block, kw_block, kd_block_pad, kh_block_pad, kw_block_pad;
    int stride_d, stride_h, stride_w;
    int dilate_d, dilate_h, dilate_w;
    format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
    bool with_bias;
    bool with_sum;
    bool with_eltwise;
    bool with_binary;

    bool is_fused_conv;
    bool is_os_blocking;
    bool is_rtus;
    int nb_ic, ic_block;
    int nb_oc, oc_block;
    int nb_iw, iw_block;
    int nb_ow, ow_block, ow_tail;
    int nb_os, os_block;
    int nb_oc_blocking;
    int nb_ic_blocking;
    int nb_os_blocking;

    data_type_t src_dt;
    data_type_t dst_dt;
    data_type_t wei_dt;
    data_type_t acc_dt;
    data_type_t bia_dt;
    // Element sizes (presumably in bytes) of the corresponding data types.
    size_t src_dsz;
    size_t wei_dsz;
    size_t dst_dsz;
    size_t acc_dsz;
    size_t bia_dsz;

    bool use_buffer;
    dim_t buffer_size;

    int is_oc_scale;

    int LDA, LDB, LDC, LDD;
    int M, N, K, M_tail, N_tail, K_tail;
    // M for the brgemm kernel. With use_store_mask it is usually greater than
    // M (M_tail); otherwise it is equal to M (M_tail).
    int brgM, brgM_tail;
    int gemm_batch_size, adjusted_batch_size;
    brgemm_batch_kind_t brg_type;
    // strides for brg_type == brgemm_strd
    dim_t brg_stride_a, brg_stride_b;
    int nthr;

    int max_batch;
    int max_vpad;

    bool wei_plain;
    bool is_ic_padded;
    int kw_sets, kh_sets;
    bool copy_block_only;
    bool amx_tile_load_xx;
    int use_M_mask;
    int oskip;
    bool brgemm_bd_loop_innermost;

    bool use_uker;
    bool use_interleave_stores;
    bool is_1x1;
};
914 
// Configuration for the x64 JIT shuffle kernels. All fields default to
// zero / undef / isa_any.
struct jit_shuffle_conf_t {
    unsigned ndims = 0;

    unsigned mb = 0, c = 0, d = 0, h = 0, w = 0, sp = 0;

    unsigned stride_mb = 0;
    unsigned blk_size = 0;
    unsigned group_size = 0;
    unsigned axis = 0;
    unsigned axis_size = 0;
    unsigned simd_tail = 0;
    unsigned simd_w = 0;

    jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
    data_type_t data_type = data_type::undef;
    size_t dt_size = 0;
    unsigned el_size_of_indices = 0;
    dim_t c_split_size = 0;
    dim_t sp_split_size = 0;

    cpu_isa_t isa = isa_any;
};
937 
// Per-invocation argument block for the shuffle JIT kernels; all fields
// default to nullptr / 0 / false.
// NOTE(review): input_off_ptr presumably points to a precomputed table of
// source offsets -- confirm.
struct jit_shuffle_call_s {
    const void *src = nullptr;
    void *dst = nullptr;
    const void *input_off_ptr = nullptr;

    dim_t cb_loop_size
            = 0; // number of loop iterations over corresponding C batches
    bool is_padded_block = false;
};
947 
// Layout-driven variant of the binary primitive's main loop: channel-blocked,
// channels-last (n_spatial_c) or channels-first (n_c_spatial); none when not
// set. Explicit values match the implicit declaration-order numbering.
enum class binary_op_t : unsigned {
    none = 0,
    c_blocked = 1,
    n_spatial_c = 2,
    n_c_spatial = 3,
};
949 
// How the second binary operand is broadcast against the first. Explicit
// values match the implicit declaration-order numbering.
enum class binary_bcast_t : unsigned {
    none = 0, // tensor operation (no broadcast)
    scalar = 1,
    per_batch = 2,
    per_c = 3,
    per_w = 4,
};
957 
// Configuration for the x64 JIT binary kernels. Every field has a default
// member initializer (no-op / undef values; outer_dims and src1_stride
// default to 1).
struct jit_binary_conf_t {
    binary_op_t op_type = binary_op_t::none;
    binary_bcast_t bcast_type = binary_bcast_t::none;
    bool do_scale_src0 = false;
    bool do_scale_src1 = false;
    bool do_sum = false;
    bool with_eltwise = false;
    bool with_binary = false;
    bool with_postops = false;
    float sum_scale = 0.f;
    bool use_stride_src1 = false;
    bool broadcast_src1_value = false;
    bool use_stride_rhs_postops = false;
    bool postops_per_oc_broadcast_exists = false;
    bool is_i8 = false;
    bool is_bf16 = false;
    bool is_src_different_layouts = false;
    dim_t outer_dims = 1;
    int src1_stride = 1;
    int not_bcasted_sp_dims = 0;

    data_type_t src0_type = data_type::undef;
    data_type_t src1_type = data_type::undef;
    data_type_t dst_type = data_type::undef;
};
983 
// Per-invocation argument block for the binary JIT kernels.
// Every member is pointer- or size_t-sized; per the note below the generated
// code depends on that, so do not add narrower members or change types.
struct jit_binary_call_s {
    // keep all sizes at 8 bytes -- jit code expects this
    const void *src0, *src1, *dst, *indices;
    const float *scales_src0, *scales_src1;
    size_t spat_offt_count;
    const void *post_ops_binary_rhs_arg_vec;
    size_t src1_stride_range;
    const void *dst_orig;
};
993 
// Configuration for the x64 JIT reduction kernels. Every field has a default
// member initializer (undef types, zero sizes, isa_any).
// NOTE(review): idle_size / reduce_size presumably are the number of
// non-reduced and reduced elements -- confirm.
struct jit_reduction_conf_t {
    data_type_t src_type = data_type::undef;
    data_type_t dst_type = data_type::undef;
    data_type_t acc_type = data_type::undef;

    std::size_t src_dt_size = 0;
    std::size_t dst_dt_size = 0;
    std::size_t acc_dt_size = 0;

    alg_kind_t alg = alg_kind::undef;
    cpu_isa_t isa = isa_any;

    dim_t idle_size = 0;
    dim_t reduce_size = 0;

    bool is_saturation_needed = false;
};
1011 
// Per-invocation argument block for the reduction JIT kernels: source and
// destination pointers, both defaulting to nullptr.
struct jit_reduction_call_s {
    const void *src = nullptr;
    void *dst = nullptr;
};
1016 
1017 } // namespace x64
1018 } // namespace cpu
1019 } // namespace impl
1020 } // namespace dnnl
1021 
1022 #endif
1023