1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef CPU_X64_JIT_PRIMITIVE_CONF_HPP
18 #define CPU_X64_JIT_PRIMITIVE_CONF_HPP
19
20 #include <queue>
21 #include <stdint.h>
22
23 #include "common/primitive_attr.hpp"
24 #include "cpu/x64/brgemm/brgemm_types.hpp"
25 #include "cpu/x64/cpu_isa_traits.hpp"
26
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31
32 /* convolution */
// Convolution kernel version tag; stored in the `ver` member of the
// convolution configuration structs declared in this header.
// Enumerator values are spelled out; they match the implicit numbering.
enum conv_version_t {
    ver_unused = 0,
    ver_fma = 1,
    ver_avx512_core = 2,
    ver_4fma = 3,
    ver_vnni = 4
};
// Loop-nest order tag for convolution drivers (see `loop_order` in
// jit_conv_conf_t). Values equal the implicit declaration order.
enum conv_loop_order_t {
    loop_cgn = 0,
    loop_gnc = 1,
    loop_ngc = 2,
    loop_gncw = 3,
    loop_cwgn = 4,
    loop_ngcw = 5,
    loop_nhwcg = 6,
    loop_nwcg = 7
};
// Loop-nest order tag for 1x1 convolutions (see `loop_order` in
// jit_1x1_conv_conf_t). Values equal the implicit declaration order.
enum conv_1x1_loop_order_t {
    loop_rbl = 0,
    loop_rlb = 1,
    loop_lbr = 2,
    loop_lrb = 3,
    loop_blr = 4,
    loop_brl = 5
};
58
// Kernel kind tag stored in jit_conv_conf_t::kernel_kind.
enum conv_kernel_kind_t {
    embd_bcast = 0,
    expl_bcast = 1
};
// Harness tag stored in the `harness` member of the convolution
// configuration structs. Values equal the implicit declaration order.
enum conv_harness_t {
    harness_2d_reduction = 0,
    harness_3d_reduction = 1,
    harness_mb_reduction = 2,
    harness_compute_full_spatial = 3,
    harness_nxc = 4
};
67
// Bit flags. Two sets share this anonymous enum; the second set reuses bit
// positions 0..2 of the first one.
// NOTE(review): presumably the two sets are never combined in a single flag
// word -- confirm against the kernels that consume them.
enum {
    // first/last markers
    FLAG_MB_FIRST = 0x001,
    FLAG_MB_LAST = 0x002,
    FLAG_OC_FIRST = 0x004,
    FLAG_OC_LAST = 0x008,
    FLAG_IC_FIRST = 0x010,
    FLAG_IC_LAST = 0x020,
    FLAG_SP_FIRST = 0x040,
    FLAG_SP_LAST = 0x080,
    FLAG_REDUCE_FIRST = 0x100,
    FLAG_REDUCE_LAST = 0x200,

    // Controls whether the inner kernel skips loading weights-data from
    // memory; this needs to happen on the first Group/16 iteration.
    FLAG_ZERO_FILTER = 0x001,
    // Controls whether the inner kernel skips loading bias data from memory.
    FLAG_ZERO_BIAS = 0x002,
    // Controls bias computation during execution pass.
    FLAG_COMPUTE_BIAS = 0x004,
};
88
// Coarse memory layout classification used by several configuration structs
// in this header (`tag_kind` members).
enum class jit_memory_tag_kind_t {
    ncsp = 0,
    nspc = 1,
    blocked = 2,
    undef = 3
};
90
91 struct jit_conv_conf_t {
92 prop_kind_t prop_kind;
93 conv_version_t ver;
94 conv_loop_order_t loop_order;
95 conv_harness_t harness;
96
97 int simd_w;
98 int ndims;
99 int mb;
100 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
101 int id, ih, iw, od, oh, ow;
102 int f_pad, l_pad, t_pad;
103 int back_pad, r_pad, b_pad;
104 int kd, kh, kw;
105 int stride_d, stride_h, stride_w;
106 int dilate_d, dilate_h, dilate_w;
107 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
108 bool with_bias;
109 bool with_sum;
110 bool with_eltwise;
111 bool with_binary;
112
113 data_type_t sum_dt;
114
115 bool with_binary_per_oc_bcast;
116 bool with_binary_no_bcast;
117
118 bool is_fused_conv;
119 int dw_conv_buffer_oc;
120
121 post_ops_t::entry_t::eltwise_t eltwise;
122 post_ops_t post_ops;
123 bool is_fast_postops; // maybe skip injector for sum and/or relu
124
125 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b, nthr_oh;
126
127 int idp, ihp, iwp, ohp, owp, icp;
128 int nb_ic, ic_block;
129 int nb_oc, oc_block;
130 int nb_iw, iw_block;
131 int nb_ow, ow_block;
132 int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking
133 into account vector registers distribution */
134 int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
135 within threads */
136 int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
137 int nb_ic_L2;
138 int h_blocking;
139 int nb_oc_L2;
140 int ic_tail, oc_tail, ch_tail;
141 int ur_h, ur_w;
142 int ur_w_tail, ur_w_blocks;
143 int ur_ic, ur_kw;
144 bool is_1stconv;
145 int nonblk_group_off;
146 /* fma avx512_core */
147 conv_kernel_kind_t kernel_kind;
148 /* 4fma */
149 int tr_iw, tr_ih;
150 int tr_kw, tr_kh;
151 int tr_src_num_guard_elems;
152
153 // Transpose buffer management
154 size_t tr_src_buf_size, tr_src_buf_count;
155 size_t tr_diff_dst_buf_size, tr_diff_dst_buf_count;
156 int nthr_mb_work;
157
158 /* 1st conv: 4fma */
159 int tr_ld;
160 int kh_step;
161 /* 4vnni */
162 int typesize_in;
163 int typesize_out;
164 int typesize_bia;
165 int typesize_acc;
166 /* avx512_u8s8u8 */
167 int ic_nb1, ic_nb2;
168 int oc_nb1;
169 int ur_ow_max, ur_ow, ur_ow_tail;
170 int ur_ow_nsteps;
171 data_type_t bia_dt;
172 /* bf16 data-type for output */
173 data_type_t dst_dt;
174 data_type_t src_dt;
175 /* bf16 weights update */
176 data_type_t wei_dt;
177 data_type_t ddst_dt;
178 data_type_t dsrc_dt;
179 data_type_t dwei_dt;
180 bool expl_bcast;
181 bool large_spatial, large_w_filter;
182 int is_ic_scale, is_oc_scale;
183 int max_regs_ur; // maximum accumulation registers
184 // dw conv
185 int nb_ch, ch_block, nb_ch_blocking;
186 bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
187 int aligned_threads;
188 // large spatial
189 int ih_blk_size, oh_blk_size;
190 // s8s8 convolution
191 bool signed_input;
192 bool need_saturation;
193 float wei_adj_scale;
194 // zero-point compensation
195 bool src_zero_point;
196 int zp_pbuff_size;
197 bool dst_zero_point;
198 bool zp_src_is_common; // common, otherwise (TODO) per-channel
199 bool req_zero_point_buffer; // used for calculating padding compensation
200 bool zp_pbuff_outer_compute; // indicates if zp_bbuff is computed in
201 // a separate parallel region
202 int ow_pad, oh_pad, od_pad; // output elements with padding & filter overlap
203
204 //output elements requiring zero-point padding compensation
205 int f_pad_output, back_pad_output;
206 int t_pad_output, b_pad_output;
207 int l_pad_output, r_pad_output;
208 // The number of output blocks corresponding to {l_pad, no_pad, r_pad}
209 int l_pad_blk, no_pad_w_blk, r_pad_blk;
210
211 bool od_mid, oh_mid, ow_mid; // indicate if there is overlap between the
212 //width and height padded regions
213
214 size_t h_blk_limits[5]; // pre-computed limits for output height block
215
216 bool uses_permw_transposition;
217 bool transpose_src;
218 bool transpose_dst;
219 int ic_block_step;
220
221 cpu_isa_t isa;
222 // bf16 bwdw conv
223 int tr_ow;
224 bool is_hw_transp; // spatial dim height-width transposed
225 int spatial_blk_size; // Height/depth block size inside the driver
226 bool global_transpose; // diff_dst & src tensors are transposed in one go
227 bool use_nt_stores_ddst; // Use non temporal stores in diff_dst transform
228
229 // Needed for Intel(R) Advanced Matrix Extensions (Intel(R) AMX) kernels
230 bool is_nspc; // activations in nwc, nhwc, or ndhwc layout
231 bool is_relo; // reduced lowering optimization
232 int nreduce; // used with is_relo
233 bool is_pbuffer_strided; // does pbuffer have strided sectors?
234 int n_stride_sets; // number of stride sectors (or sets) in pbuffer
235 int kw_step; // usually stride_w, unless !is_pbuffer_strided
236 int kw_per_tile; // mostly for 1st convs
237 // The suffix _int refers to the block sizes of the src and diff_dst tiles,
238 // as opposed to the vector registers. This distinction is needed due to
239 // support for blocked layout (ie nChw16c) with bf16 data type.
240 int ic_block_int, ic_block_int_np, oc_block_int;
241 int nb_ic_int, nb_oc_int;
242 int nb_ih_blocking, nb_oh_blocking;
243
244 int full_tile_width;
245 int max_tiles;
246 int tile_width;
247 int tile_tail;
248 int oh_per_tile;
249 int iw_blocks, ow_blocks;
250
251 int per_one_pstore;
252
253 size_t inp_buffer_size;
254 size_t wei_buffer_size;
255 size_t wsp_buffer_size;
256
257 int nb_os;
258 int nb_os_blocking;
259 int nb_os2_blocking;
260 int os_tail;
261 int os_blocked;
262 int max_width;
263
264 bool transform_to_vnni;
265 };
266
// Calculates the filter size along one dimension taking dilation into
// account: (filter_size - 1) * (dilation + 1) + 1.
// With dilation == 0 the result equals filter_size.
inline int calculate_extended_filter_size(int filter_size, int dilation) {
    return (filter_size - 1) * (dilation + 1) + 1;
}
271
// Computes the padding at the end of one spatial dimension:
//   (dst_size - 1) * spatial_stride + dilated_filter_size
//       - (src_size + start_padding)
// The value is returned as computed and may be negative.
inline int calculate_end_padding(int start_padding, int dst_size, int src_size,
        int spatial_stride, int dilated_filter_size) {
    return (dst_size - 1) * spatial_stride + dilated_filter_size
            - (src_size + start_padding);
}
277
init_tag(format_tag_t & tag,const memory_desc_wrapper & mdw,const format_tag_t & tag_value)278 inline status_t init_tag(format_tag_t &tag, const memory_desc_wrapper &mdw,
279 const format_tag_t &tag_value) {
280 if (mdw.format_kind() == format_kind::any) return status::unimplemented;
281
282 tag = mdw.matches_one_of_tag(tag_value);
283 return tag == tag_value ? status::success : status::unimplemented;
284 }
285
286 struct jit_conv_conf_2x3_wino_t {
287 conv_version_t ver;
288
289 int m;
290 int r;
291 int alpha;
292 int tile_h, tile_w;
293
294 int mb;
295 int ngroups, ic, oc, oc_without_padding;
296 int ih, iw, oh, ow;
297 int l_pad, t_pad;
298 int r_pad, b_pad;
299 int kh, kw;
300 int stride_h, stride_w;
301 int dilate_h, dilate_w;
302
303 int nb_ic, ic_block;
304 int nb_oc, oc_block;
305
306 int w_block_size, h_block_size;
307
308 data_type_t bia_dt;
309 data_type_t dst_dt;
310
311 int is_oc_scale;
312 int typesize_in;
313 int typesize_out;
314 int typesize_bia;
315 int typesize_acc;
316
317 format_tag_t src_tag, dst_tag; // temporary workaround
318 bool with_bias;
319 bool small_mb;
320
321 int xb, yb;
322 int inp_stride;
323 int out_stride;
324 int wei_stride;
325 int bia_stride;
326
327 int M, N, K;
328 int m_block, n_block, k_block;
329 int n2_block, n_chunks;
330 int k2_block, k_chunks;
331
332 int mb_block, nb_mb;
333
334 size_t size_wino_src, size_wino_wei, size_wino_dst;
335
336 int nthr;
337 };
338
339 /*
340 Winograd sched policy:
341
342 Computation Unit:
343 W: weights transform
344 S: src transform
345 D: dst transform
346 G: gemm
347
348 Thread grouping by:
349 i: nb_ic
350 o: nb_oc
351 t: tile_block
352 e: element in tile
353
354 Note: 'i' and 'o' are omitted if
355 i. not combined with t or
356 ii. with discrete transforms
357
358 Current policies supported:
359 */
// Winograd scheduling policies. Values equal the implicit numbering that
// follows WSCHED_INVALID = 0.
enum winograd_sched_t {
    WSCHED_INVALID = 0,

    // Forward & backward-data.
    // W_S_G_D implements discrete transforms
    WSCHED_DATA_W_S_G_D = 1,
    // W_SGD implements tiled transforms s.t. GEMM could reuse data in L2
    WSCHED_DATA_W_SGD = 2,

    // Backward-weights.
    WSCHED_WEI_S_D_G_W = 3,
    WSCHED_WEI_SDGtWo = 4,
    WSCHED_WEI_S_D_Giot_W = 5,
    WSCHED_WEI_SDGt_W = 6,
};
375
376 struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
377 int itiles;
378 int jtiles;
379 int ntiles;
380 int ic_simd_block = 16;
381 int tile_4fma_padding;
382 int tile_4fma;
383 int oc_simd_block = 16;
384 int oc_reg_block;
385 int ic_reg_block;
386 int tile_block;
387 int tile_block_ur;
388 int nb_tile_block_ur;
389
390 bool double_buffering;
391 bool with_relu_postsum;
392 int zmm_start;
393 int nb_reg;
394
395 int dimK;
396 int dimK_4fma;
397 int dimK_reg_block;
398 int dimK_block;
399 int dimK_nb_block;
400
401 int dimM;
402 int dimM_reg_block;
403 int dimM_simd_block;
404 int dimM_block;
405 int dimM_nb_block;
406
407 int dimN;
408 int dimN_reg_block;
409 int dimN_bcast_ur;
410 int dimN_block;
411 int dimN_nb_block;
412
413 winograd_sched_t sched_policy;
414 };
415
// Argument block for convolution kernels. Member order is kept exactly as
// originally declared; related members are declared together.
struct jit_conv_call_s {
    // hack: src is non-const for backward_data, dst for forward,
    // filt for backward_weights, bias for backward_bias
    const void *src, *dst, *filt, *bias;
    const void *src_prf, *dst_prf, *filt_prf, *bias_prf;
    const void *scales, *acc_s32, *compensation;
    const int32_t *zp_compensation, *src_zero_point;
    const int32_t *zero_point_pbuff, *dst_zero_point;
    const void *tile_cfg, *tile_cfg_tail;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // pointer to dst memory (no offset)
    const void *dst_orig;

    size_t oc_l_off_prf;
    const void *dst_orig_prf;

    size_t kd_offset, kd_offset_prf;
    size_t kh_offset, kh_offset_prf;
    size_t os_index_begin, os_index_begin_prf;
    size_t os_index_end, os_index_end_prf;
    size_t kd_padding, kd_padding_prf;
    size_t kh_padding, kh_padding_prf;
    size_t iwb, iwb_prf;
    size_t owb, owb_prf;
    size_t ohb;
    size_t kw_padding;
    size_t channel, channel_prf;
    size_t ic_blocks, oc_blocks;
    size_t ur_w, ur_str_w;
    size_t ch_blocks, ch_blocks_prf;
    size_t reduce_work, reduce_work_prf;
    size_t load_work, load_work_prf;
    size_t l_overflow, r_overflow;
    size_t t_overflow, b_overflow;
    size_t f_overflow, back_overflow;
    size_t last_h;
    size_t tail;
    size_t current_iw;
    size_t is_osb;
    int flags, flags_prf;
    int oc_flag;
    size_t last_ic_block, last_oc_block;
};
492
// Argument block for deconvolution kernels. Member order is kept exactly as
// originally declared.
struct jit_deconv_call_s {
    // hack: src is non-const for backward_data, dst for forward,
    // filt for backward_weights, bias for backward_bias
    const void *src, *dst, *filt, *bias;
    const void *scales, *compensation;
    const int32_t *zp_src_pad_str_compensation, *zp_compensation;
    const int32_t *src_zero_point, *dst_zero_point;

    // ptr to table of void * elements that are pointers to post_op binary
    // src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // pointer to dst memory (no offset)
    const void *dst_orig;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    size_t t_overflow, b_overflow;
    size_t f_overflow, back_overflow;
    size_t kh_padding, kd_padding;
    size_t oc_blocks;
};
524
// Argument block for depthwise convolution kernels. Member order is kept
// exactly as originally declared.
struct jit_dw_conv_call_s {
    const void *input, *output;
    const void *filter, *bias;
    size_t kh_count;
    size_t oh_count, oh_index;
    size_t filter_pad_off;
    // Flags passed by driver execution to inner kernel
    unsigned char exec_flags;
};
537
// Argument block for Winograd transform kernels. Member order is kept exactly
// as originally declared.
struct jit_wino_transform_call_s {
    size_t tile_block, tile_block_ur, nb_tile_block_ur;
    size_t tile_count;
    size_t tj, ti;
    void *src, *dst;
    void *Mw, *M, *T, *G;
    void *bias;
};
553
554 struct jit_1x1_conv_conf_t {
555 prop_kind_t prop_kind;
556 conv_version_t ver;
557
558 int ndims;
559 int mb;
560 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
561 int id, ih, iw, od, oh, ow;
562 int f_pad, t_pad, l_pad;
563 int kd, kh, kw;
564 int stride_d, stride_h, stride_w;
565 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
566 bool with_bias;
567 bool with_sum;
568 bool with_eltwise;
569 bool with_binary;
570 bool with_dw_conv;
571
572 post_ops_t post_ops;
573
574 int is, os;
575 int ic_block, oc_block;
576
577 int ur, ur_tail;
578
579 int reduce_dim, reduce_block, nb_reduce, nb_reduce_blocking,
580 nb_reduce_blocking_max;
581 int load_dim, load_block, nb_load, nb_load_blocking, nb_load_blocking_max,
582 nb_load_chunk;
583 int bcast_dim, bcast_block, nb_bcast, nb_bcast_blocking,
584 nb_bcast_blocking_max;
585
586 int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
587 int load_loop_load_step, load_loop_iter_step;
588 int bcast_loop_output_step, bcast_loop_output_substep;
589 int bcast_loop_bcast_step, bcast_loop_bcast_substep;
590 int fma_step;
591 int load_grp_count;
592 conv_1x1_loop_order_t loop_order;
593 bool use_vmovntps;
594 /* avx512 core */
595 bool expl_bcast;
596 /* 4vnni */
597 int typesize_in;
598 int typesize_out;
599 int typesize_bia;
600 int typesize_acc;
601 /* 4fma */
602 bool transpose_src;
603 int tr_is;
604 int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
605 int is_oc_scale;
606 data_type_t bia_dt;
607 data_type_t dst_dt;
608 data_type_t sum_dt;
609 bool signed_input;
610 float wei_adj_scale;
611 // zero-point compensation
612 bool src_zero_point;
613 bool dst_zero_point;
614 bool zp_src_is_common; // common, otherwise (TODO) per-channel
615
616 cpu_isa_t isa;
617 bool uses_permw_transposition;
618 };
619
// Argument block for 1x1 convolution kernels. Member order is kept exactly
// as originally declared.
struct jit_1x1_conv_call_s {
    const void *bcast_data, *load_data, *output_data;
    // used in forward and backward_weights only
    const void *bias_data;
    const void *acc_s32, *scales, *compensation, *store_buffer;
    const int32_t *zp_compensation;
    const int32_t *src_zero_point, *dst_zero_point;

    // ptr to table of void * elements that are pointers to
    // post_op binary src1 tensors
    const void *post_ops_binary_rhs_arg_vec;
    // logical (# of elems) offset to the processed output channel
    // (for broadcasting [1,OC,1,1])
    size_t oc_l_off;
    // logical (# of elems) offset to the processed pixel
    // (for non-broadcasting policy)
    size_t dst_l_off;
    // pointer to dst memory (no offset)
    const void *dst_orig;

    size_t load_dim, bcast_dim, reduce_dim;

    // used in backward_weights only
    size_t output_stride;

    size_t first_last_flag;
};
652
653 struct jit_pool_conf_t {
654 int ndims;
655 int mb, c, c_without_padding;
656 int id, ih, iw, od, oh, ow;
657 int stride_d, stride_h, stride_w;
658 int kd, kh, kw;
659 int f_pad, t_pad, l_pad;
660 alg_kind_t alg;
661 bool is_training;
662 bool pad_w_is_null;
663 bool is_backward;
664 bool simple_alg;
665 bool is_c_padded;
666 data_type_t ind_dt;
667
668 int c_block, c_tail, nb_c;
669 int ur_bc, ur_bc_tail;
670 int ur_c, ur_c_tail;
671 int ur;
672 size_t tail[4];
673 bool safe_c_tail;
674 data_type_t src_dt;
675 data_type_t dst_dt;
676
677 int dt_size;
678 bool is_bf16;
679 jit_memory_tag_kind_t tag_kind;
is_plaindnnl::impl::cpu::x64::jit_pool_conf_t680 bool is_plain() const {
681 return (tag_kind == jit_memory_tag_kind_t::ncsp
682 || tag_kind == jit_memory_tag_kind_t::nspc);
683 }
684
685 cpu_isa_t isa;
686 post_ops_t post_ops;
687 bool with_postops;
688 bool with_eltwise;
689 bool with_binary;
690 int nthr;
691 };
692
// Argument block for pooling kernels. Member order is kept exactly as
// originally declared.
struct jit_pool_call_s {
    const void *src, *dst, *indices;
    const void *src_prf, *dst_prf, *indices_prf;
    const void *post_ops_binary_rhs_arg_vec;
    const void *dst_orig;
    size_t c_elem_off;
    size_t zero_ih, zero_id;
    const void *zero_ptr;
    size_t kd_padding, kh_padding;
    size_t kh_padding_shift, kd_padding_shift;
    size_t kw_padding;
    const void *init_value;
    float ker_area_h;
    // number of channel blocks to process
    size_t ur_bc;
    // number of channel blocks already processed
    size_t b_c;
};
716
717 struct jit_resampling_conf_t {
718 unsigned ndims = 0;
719
720 unsigned c = 0;
721 unsigned id = 0, ih = 0, iw = 0;
722 unsigned od = 0, oh = 0, ow = 0;
723
724 unsigned stride_d = 0;
725 unsigned stride_h = 0;
726 unsigned stride_w = 0;
727 unsigned inner_stride = 0;
728
729 // The linear algorithm is an approximation of the point
730 // value based on the limit values. For one dimension,
731 // the approximation is based on the line, for two
732 // dimensions it will be a rectangle, and for three
733 // dimensions it will be a cuboid. Therefore,
734 // the possible variants for the number of corners are 2, 4, 8.
735 unsigned number_of_corners = 0;
736
737 bool is_data_size_bigger_than_L3 = false;
738 bool is_saturation_needed = false;
739 data_type_t src_data_type = data_type::undef;
740 data_type_t dst_data_type = data_type::undef;
741 size_t src_dt_size = 0;
742 size_t dst_dt_size = 0;
743 size_t output_data_size = 0;
744 size_t el_size_of_indices = 0;
745
746 bool is_blocked_8_format = false;
747 format_tag_t src_tag = format_tag::undef;
748 jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
749 alg_kind_t alg = alg_kind::undef;
750
751 cpu_isa_t isa = isa_any;
752
753 post_ops_t post_ops = post_ops_t();
754 bool with_postops = false;
755 bool with_eltwise = false;
756 bool with_binary = false;
757 bool with_sum = false;
758 std::queue<float> sum_scales;
759 };
760
// Argument block for resampling kernels. All members default to zero /
// nullptr; member order is kept exactly as originally declared.
struct jit_resampling_call_s {
    size_t batch_of_sp_points_to_process = 0;

    const void *src = nullptr, *dst = nullptr;
    const void *indices = nullptr, *weights = nullptr;
    const void *post_ops_binary_rhs_arg_vec = nullptr;
    const void *dst_orig = nullptr;

    size_t c_offset = 0;

    size_t src_offset_top = 0, src_offset_bottom = 0;
    size_t src_offset_front = 0, src_offset_back = 0;

    float weight_top = 0.0f, weight_bottom = 0.0f;
    float weight_front = 0.0f, weight_back = 0.0f;
};
783
784 struct jit_brdgmm_conv_conf_t {
785
786 int nthr;
787 int mb, ngroups, ic, oc;
788 int ih, iw, oh, ow;
789 int l_pad, r_pad, t_pad, b_pad;
790 int kh, kw;
791 int stride_h, stride_w;
792 int nb_ch, ch_block, chb_tail;
793 int nb_ch_blocking;
794 int ow_block, ow_tail, nb_ow;
795 // idx of jit kernel when mutiple jit kernels are used in a primitive.
796 int chb_tail_idx, ow_tail_idx, nb_ch_blocking_idx;
797 int adjusted_batch_size;
798
799 bool with_bias;
800 bool with_post_ops;
801 bool is_oc_scale;
802
803 data_type_t src_dt;
804 data_type_t wei_dt;
805 data_type_t bia_dt;
806 data_type_t dst_dt;
807
808 brgemm_batch_kind_t batch_kind;
809
810 size_t src_dsz;
811 size_t wei_dsz;
812 size_t bia_dsz;
813 size_t dst_dsz;
814
815 cpu_isa_t isa;
816 };
817
// Loop-nest order tag for brgemm convolutions (see `loop_order` in
// jit_brgemm_conv_conf_t). Values equal the implicit declaration order.
enum conv_brgemm_loop_order_t {
    loop_ndhwgc = 0,
    loop_ngcdhw = 1,
};
822
// Execution type tag for brgemm convolutions (see `exec_type` in
// jit_brgemm_conv_conf_t). Values equal the implicit numbering that follows
// exec_undefined = 0.
enum conv_brgemm_exec_type_t {
    exec_undefined = 0,
    exec_base = 1,
    exec_trans = 2,
    exec_vpad = 3,
};
829
830 struct jit_brgemm_conv_conf_t {
831 cpu_isa_t isa;
832 prop_kind_t prop_kind;
833 conv_version_t ver;
834 conv_brgemm_loop_order_t loop_order;
835 conv_harness_t harness;
836 int simd_w, amx_w, amx_h;
837 int ndims;
838 int mb;
839 int ngroups, ic, oc, oc_without_padding, ic_without_padding;
840
841 int od_block, oh_block, nb_od,
842 nb_oh; // blocking - included in parallelization
843 dim_t inp_buffer_size, inp_buffer_mask_size;
844 conv_brgemm_exec_type_t exec_type;
845
846 int id, ih, iw, od, oh, ow, os, idp, ihp, iwp, icp;
847 int f_pad, l_pad, t_pad;
848 int back_pad, r_pad, b_pad;
849 int kd, kh, kw;
850 int ext_kd, ext_kh, ext_kw;
851 int kd_block, kh_block, kw_block, kd_block_pad, kh_block_pad, kw_block_pad;
852 int stride_d, stride_h, stride_w;
853 int dilate_d, dilate_h, dilate_w;
854 format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
855 bool with_bias;
856 bool with_sum;
857 bool with_eltwise;
858 bool with_binary;
859
860 bool is_fused_conv;
861 bool is_os_blocking;
862 bool is_rtus;
863 int nb_ic, ic_block;
864 int nb_oc, oc_block;
865 int nb_iw, iw_block;
866 int nb_ow, ow_block, ow_tail;
867 int nb_os, os_block;
868 int nb_oc_blocking;
869 int nb_ic_blocking;
870 int nb_os_blocking;
871
872 data_type_t src_dt;
873 data_type_t dst_dt;
874 data_type_t wei_dt;
875 data_type_t acc_dt;
876 data_type_t bia_dt;
877 size_t src_dsz;
878 size_t wei_dsz;
879 size_t dst_dsz;
880 size_t acc_dsz;
881 size_t bia_dsz;
882
883 bool use_buffer;
884 dim_t buffer_size;
885
886 int is_oc_scale;
887
888 int LDA, LDB, LDC, LDD;
889 int M, N, K, M_tail, N_tail, K_tail;
890 // M for brgemm kernel. For use_store_mask it is usually greater than M (M_tail). Otherwise it is equal to M (M_tail)
891 int brgM, brgM_tail;
892 int gemm_batch_size, adjusted_batch_size;
893 brgemm_batch_kind_t brg_type;
894 // strides for brg_type == brgemm_strd
895 dim_t brg_stride_a, brg_stride_b;
896 int nthr;
897
898 int max_batch;
899 int max_vpad;
900
901 bool wei_plain;
902 bool is_ic_padded;
903 int kw_sets, kh_sets;
904 bool copy_block_only;
905 bool amx_tile_load_xx;
906 int use_M_mask;
907 int oskip;
908 bool brgemm_bd_loop_innermost;
909
910 bool use_uker;
911 bool use_interleave_stores;
912 bool is_1x1;
913 };
914
915 struct jit_shuffle_conf_t {
916 unsigned ndims = 0;
917
918 unsigned mb = 0, c = 0, d = 0, h = 0, w = 0, sp = 0;
919
920 unsigned stride_mb = 0;
921 unsigned blk_size = 0;
922 unsigned group_size = 0;
923 unsigned axis = 0;
924 unsigned axis_size = 0;
925 unsigned simd_tail = 0;
926 unsigned simd_w = 0;
927
928 jit_memory_tag_kind_t tag_kind = jit_memory_tag_kind_t::undef;
929 data_type_t data_type = data_type::undef;
930 size_t dt_size = 0;
931 unsigned el_size_of_indices = 0;
932 dim_t c_split_size = 0;
933 dim_t sp_split_size = 0;
934
935 cpu_isa_t isa = isa_any;
936 };
937
938 struct jit_shuffle_call_s {
939 const void *src = nullptr;
940 void *dst = nullptr;
941 const void *input_off_ptr = nullptr;
942
943 dim_t cb_loop_size
944 = 0; // number of loop iterations over corresponding C batches
945 bool is_padded_block = false;
946 };
947
// Operation type tag stored in jit_binary_conf_t::op_type.
enum class binary_op_t : unsigned {
    none = 0,
    c_blocked = 1,
    n_spatial_c = 2,
    n_c_spatial = 3
};
949
// Broadcast type tag stored in jit_binary_conf_t::bcast_type.
enum class binary_bcast_t : unsigned {
    // tensor operation
    none = 0,
    scalar = 1,
    per_batch = 2,
    per_c = 3,
    per_w = 4
};
957
958 struct jit_binary_conf_t {
959 binary_op_t op_type = binary_op_t::none;
960 binary_bcast_t bcast_type = binary_bcast_t::none;
961 bool do_scale_src0 = false;
962 bool do_scale_src1 = false;
963 bool do_sum = false;
964 bool with_eltwise = false;
965 bool with_binary = false;
966 bool with_postops = false;
967 float sum_scale = 0.f;
968 bool use_stride_src1 = false;
969 bool broadcast_src1_value = false;
970 bool use_stride_rhs_postops = false;
971 bool postops_per_oc_broadcast_exists = false;
972 bool is_i8 = false;
973 bool is_bf16 = false;
974 bool is_src_different_layouts = false;
975 dim_t outer_dims = 1;
976 int src1_stride = 1;
977 int not_bcasted_sp_dims = 0;
978
979 data_type_t src0_type = data_type::undef;
980 data_type_t src1_type = data_type::undef;
981 data_type_t dst_type = data_type::undef;
982 };
983
// Argument block for binary kernels. Member order is kept exactly as
// originally declared.
struct jit_binary_call_s {
    // keep all sizes at 8 bytes -- jit code expects this
    const void *src0;
    const void *src1;
    const void *dst;
    const void *indices;
    const float *scales_src0;
    const float *scales_src1;
    size_t spat_offt_count;
    const void *post_ops_binary_rhs_arg_vec;
    size_t src1_stride_range;
    const void *dst_orig;
};
993
994 struct jit_reduction_conf_t {
995 data_type_t src_type = data_type::undef;
996 data_type_t dst_type = data_type::undef;
997 data_type_t acc_type = data_type::undef;
998
999 std::size_t src_dt_size = 0;
1000 std::size_t dst_dt_size = 0;
1001 std::size_t acc_dt_size = 0;
1002
1003 alg_kind_t alg = alg_kind::undef;
1004 cpu_isa_t isa = isa_any;
1005
1006 dim_t idle_size = 0;
1007 dim_t reduce_size = 0;
1008
1009 bool is_saturation_needed = false;
1010 };
1011
// Argument block for reduction kernels: source and destination pointers,
// both defaulting to nullptr.
struct jit_reduction_call_s {
    const void *src = nullptr;

    void *dst = nullptr;
};
1016
1017 } // namespace x64
1018 } // namespace cpu
1019 } // namespace impl
1020 } // namespace dnnl
1021
1022 #endif
1023