/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_RNN_UTILS_HPP
#define CPU_RNN_RNN_UTILS_HPP

#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/gemm_pack.hpp"

#if DNNL_X64
#include "cpu/x64/cpu_isa_traits.hpp"
#endif

#define rnn_postgemm_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, gates_t *ws_gates_, \
            scratch_t *scratch_gates_, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            gemm_acc_t *diff_dst_layer_, gemm_acc_t *diff_dst_iter_, \
            gemm_acc_t *diff_dst_iter_c_, const float *weights_peephole_, \
            float *bias_, gates_t *ws_grid_, scratch_t *scratch_cell_, \
            dst_iter_t *dst_iter_, float *weights_scales_, int block_step) \
            const

#if DNNL_X64
#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            weights_t **w_layer_, weights_t **w_iter_, \
            weights_t **w_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            dst_iter_t *dst_iter_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, float *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            float *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const
#else
#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            weights_t **w_layer_, weights_t **w_iter_, \
            weights_t **w_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, dst_iter_t *dst_iter_, \
            gemm_acc_t *amx_scratchpad) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, float *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            float *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, gemm_acc_t *diff_weights_layer_, \
            gemm_acc_t *diff_weights_iter_, float *diff_weights_projection_, \
            float *diff_weights_peephole_, float *diff_bias_, \
            gemm_acc_t *amx_scratchpad) const
#endif

#define rnn_gemm_sig(f) \
    dnnl_status_t f(const char transA, const char transB, dim_t m, dim_t n, \
            dim_t k, const float alpha, const weights_t *a_, const dim_t ldA, \
            const gemm_data_t *b_, const dim_t ldB, const float beta, \
            gemm_acc_t *c_, const dim_t ldC) const

#define rnn_bias_prepare_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
            float *scratch_bias_) const

#define rnn_bias_finalize_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
            const float *w_iter_comp, const float *w_layer_comp) const

#define rnn_weights_assign_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, \
            int n_parts, const int *gates_per_part, weights_t **weights_, \
            const weights_t *w_) const

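// Note: the *_sig macros above expand to member-function signatures only; the
// RNN cell/grid implementations supply the bodies. A minimal, hedged sketch
// (the class name `my_rnn_impl_t` is hypothetical, not part of this header):
//
//     struct my_rnn_impl_t {
//         rnn_bias_prepare_sig(bias_prepare); // declares the member function
//     };
//     rnn_bias_prepare_sig(my_rnn_impl_t::bias_prepare) { /* definition */ }
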
namespace dnnl {
namespace impl {
namespace cpu {

namespace rnn_utils {

enum execution_direction_t {
    l2r,
    r2l,
    bi_concat,
    bi_sum,
};

enum cell_position_t {
    middle_cell = 0x0,
    first_layer = 0x1,
    first_iter = 0x2,
    last_layer = 0x4,
    last_iter = 0x8,
    c_state_first_iter = 0x10,
    c_state_last_iter = 0x20
};

enum class weights_type_t {
    layer,
    iter,
    projection,
    peephole,
};

inline cell_position_t &operator|=(cell_position_t &lhs, cell_position_t rhs) {
    lhs = static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
    return lhs;
}

inline cell_position_t operator|(cell_position_t lhs, cell_position_t rhs) {
    return static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
}
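
// Illustrative sketch only (`lay`, `it` and `n_iter` are hypothetical loop
// variables from a grid loop): cell positions are bit flags, so a cell can be
// in the first layer and the last iteration at the same time.
//
//     cell_position_t pos = middle_cell;
//     if (lay == 0) pos |= first_layer;
//     if (it == n_iter - 1) pos |= last_iter;
//     const bool is_last_iter = (pos & last_iter) != 0;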

enum data_type_conf_t {
    all_f32,
    all_bf16,
    u8u8u8f32,
    f32u8f32f32,
    u8u8u8u8,
    f32u8f32u8,
    s8s8s8f32,
    f32s8f32f32,
    s8s8s8s8,
    f32s8f32s8
};

struct diff_src_brgemm_conf_t {
    dim_t M = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;

    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t N_iter = 0, N_layer = 0;
    dim_t N_layer_blocks = 0, n_layer_tail = 0;
    dim_t N_iter_blocks = 0, n_iter_tail = 0;
    dim_t LDA = 0, LDB = 0, LDC = 0;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_any;
#endif
};

struct diff_wei_brgemm_conf_t {
    dim_t M = 0, M_layer = 0, M_iter = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;
    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t LDA_layer = 0, LDA_iter = 0, LDB = 0, LDC_iter = 0, LDC_layer = 0;

    bool global_transpose = false;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_any;
#endif
};

struct rnn_conf_t {
    execution_direction_t exec_dir;
    data_type_conf_t dt_conf;
    int n_layer = 0, n_iter = 0, n_dir = 0, n_gates = 0, n_states = 0;
    int mb = 0;
    int slc = 0, sic = 0, dhc = 0, dic = 0, dlc = 0;
    //int gates_ld, gates_nld, gates_ws_ld;

    int n_parts_weights_layer = 0;
    int parts_weights_layer[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_iter = 0;
    int parts_weights_iter[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_projection = 0;
    int parts_weights_projection[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_projection_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_bias = 0, n_parts_bias = 0, parts_bias[DNNL_RNN_MAX_N_PARTS];

    /* Size of packed data in bytes */
    size_t weights_layer_comp_offset = 0, weights_layer_pack_size = 0;
    size_t weights_iter_comp_offset = 0, weights_iter_pack_size = 0;
    size_t weights_projection_comp_offset = 0, weights_projection_pack_size = 0;

    bool copy_bias = 0;
    int weights_layer_ld = 0, weights_layer_nld = 0;
    int diff_weights_layer_ld = 0, diff_weights_layer_nld = 0;
    int weights_iter_ld = 0, weights_iter_nld = 0;
    int diff_weights_iter_ld = 0, diff_weights_iter_nld = 0;
    int weights_projection_ld = 0, weights_projection_nld = 0;
    int diff_weights_projection_ld = 0, diff_weights_projection_nld = 0;

    int proj_ht_ld = 0, proj_ht_nld = 0;

    int ws_gates_ld = 0, ws_gates_nld = 0;
    int ws_ht_ld = 0, ws_ht_nld = 0;
    int ws_states_layer_ld = 0, ws_states_layer_nld = 0;
    int ws_states_iter_ld = 0, ws_states_iter_nld = 0;
    int ws_states_iter_c_ld = 0, ws_states_iter_c_nld = 0;
    int ws_diff_states_layer_ld = 0, ws_diff_states_layer_nld = 0;
    int ws_diff_states_iter_ld = 0, ws_diff_states_iter_nld = 0;
    int ws_diff_states_iter_c_ld = 0, ws_diff_states_iter_c_nld = 0;

    int scratch_gates_ld = 0, scratch_gates_nld = 0;
    int scratch_ht_ld = 0, scratch_ht_nld = 0;
    int scratch_diff_ht_ld = 0, scratch_diff_ht_nld = 0;

    int src_layer_ld_ = 0, src_layer_nld_ = 0;
    int src_iter_ld_ = 0, src_iter_nld_ = 0;
    int src_iter_c_ld_ = 0, src_iter_c_nld_ = 0;
    int dst_layer_ld_ = 0, dst_layer_nld_ = 0;
    int dst_iter_ld_ = 0, dst_iter_nld_ = 0;
    int dst_iter_c_ld_ = 0, dst_iter_c_nld_ = 0;

    int weights_iter_compensation_size = 0, weights_layer_compensation_size = 0;
    bool is_fwd = 0, is_training = 0, is_lbr = 0, is_lstm_peephole = 0,
         is_lstm_projection = 0;
    bool use_workspace = 0;

    // Size of workspace for each tensor in bytes
    // Notes:
    // 1. For non-LSTMP ws_states_iter_size == ws_states_layer_size. The corresponding
    //    pointers should point to the same places.
    size_t ws_gates_size = 0;
    size_t ws_ht_size = 0;
    size_t ws_states_layer_size = 0;
    size_t ws_states_iter_size = 0;
    size_t ws_states_iter_c_size = 0;
    size_t ws_diff_states_layer_size = 0;
    size_t ws_diff_states_iter_size = 0;
    size_t ws_diff_states_iter_c_size = 0;
    size_t scratch_gates_size = 0;

    size_t scratch_gates_blocked_size = 0;
    size_t scratch_gates_blocked_nested_reorder_size = 0;
    size_t scratch_src_layer_size = 0;
    size_t scratch_src_layer_nested_reorder_size = 0;
    size_t scratch_src_iter_size = 0;
    size_t scratch_src_iter_nested_reorder_size = 0;

    size_t scratch_ht_size = 0;
    size_t scratch_diff_ht_size = 0;
    size_t scratch_cell_size = 0;
    size_t ws_grid_comp_size = 0;
    size_t ws_per_cell = 0;
    size_t ws_bias_size = 0;

    bool merge_gemm_iter = false, merge_gemm_layer = false,
         force_nocopy = false, use_layer_packed_gemm = false,
         use_iter_packed_gemm = false, use_projection_packed_gemm = false;
    int n_iter_scratch_gates = 0;

    inline bool is_int8() const {
        return is_signed_int8() || is_unsigned_int8();
    }
    inline bool is_signed_int8() const {
        return utils::one_of(
                dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8, f32s8f32s8);
    }
    inline bool is_unsigned_int8() const {
        return utils::one_of(
                dt_conf, u8u8u8f32, f32u8f32f32, u8u8u8u8, f32u8f32u8);
    }
    inline bool is_int8_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_bf16_amx_int8 && is_int8();
#else
        return false;
#endif
    }
    inline bool is_bf16() const { return dt_conf == all_bf16; }
    inline bool is_bf16_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_bf16_amx_bf16 && is_bf16();
#else
        return false;
#endif
    }
    inline bool is_f32() const { return dt_conf == all_f32; }

    inline bool skip_src_layer_copy() const {
        // Note: this currently always returns true
        return (exec_dir == l2r)
                && utils::one_of(dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8,
                        f32s8f32s8, u8u8u8u8, u8u8u8f32, f32u8f32u8,
                        f32u8f32f32, all_f32, all_bf16);
    }
    inline bool skip_src_iter_copy() const {
        return (exec_dir == l2r) && (src_iter_ld_ > 0)
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }
    inline bool skip_dst_layer_copy() const {
        return (exec_dir == l2r)
                && utils::one_of(dt_conf, s8s8s8s8, f32s8f32s8, u8u8u8u8,
                        f32u8f32u8, all_f32, all_bf16);
    }
    inline bool skip_dst_iter_copy() const {
        return (exec_dir == l2r) && (dst_iter_ld_ > 0)
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }

    inline dim_t src_layer_ld(cell_position_t cell_position) const {
        return (cell_position & first_layer) && skip_src_layer_copy()
                ? src_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t src_iter_ld(cell_position_t cell_position) const {
        return (cell_position & first_iter) && skip_src_iter_copy()
                ? src_iter_ld_
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                                        && !(cell_position & first_iter)
                                ? dst_layer_ld_
                                : ws_states_iter_ld);
    }

    inline dim_t layer_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_layer) && skip_src_layer_copy())
                ? 0
                : ((cell_position & last_iter) && skip_dst_iter_copy()) ? 1 : 2;
    }

    inline dim_t iter_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_iter) && skip_src_iter_copy())
                ? 0
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                          && !(cell_position & first_iter))
                        ? 1
                        : 2;
    }

    inline dim_t src_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_first_iter) ? src_iter_c_ld_
                                                    : ws_states_iter_c_ld;
    }

    inline dim_t dst_layer_ld(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return scratch_ht_ld;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? dst_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t dst_brgemm_desc(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return 0;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? 1
                : (cell_position & last_iter) && skip_dst_iter_copy() ? 2 : 3;
    }

    inline dim_t dst_iter_ld(cell_position_t cell_position) const {
        return (cell_position & last_iter) && skip_dst_iter_copy()
                ? dst_iter_ld_
                : ws_states_iter_ld;
    }

    inline dim_t dst_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_last_iter) ? dst_iter_c_ld_
                                                   : ws_states_iter_c_ld;
    }

    // // when skipping copy, the output ld can be states_ws_ld,
    // // dst_iter_ld or dst_layer_ld depending on the cell position
    // inline dim_t dst_ld(cell_position_t cell_position) const {
    //     return (cell_position & last_layer) ? dst_layer_ld(cell_position)
    //                                         : dst_iter_ld(cell_position);
    // }
    inline dim_t dst_copy_ld(cell_position_t cell_position) const {
        return dst_iter_ld(cell_position);
    }

    inline bool need_gemm_layer(cell_position_t cell_position) const {
        // In case of merge_gemm_layer we might still need a layer gemm if we
        // store the states of the last iteration in the destination memory.
        // The exception to this rule is the first layer, in which case all
        // states are kept in the user's src_layer, hence making a full merged
        // gemm possible.
        return IMPLICATION(merge_gemm_layer,
                skip_dst_iter_copy() && (cell_position & last_iter)
                        && !(cell_position & first_layer));
    }
    bool is_brgemm;

    diff_src_brgemm_conf_t diff_src_brgemm;
    diff_wei_brgemm_conf_t diff_wei_brgemm;

    dim_t M, N, K1, K2;

    dim_t LDB1, LDB2;
    dim_t LDA1[3];
    dim_t LDA2[3];
    dim_t LDC;

    dim_t m_block, M_blocks;
    dim_t n_block, N_blocks, n_tail;

    dim_t k2_block, k1_block, k1_tail, k2_tail;
    dim_t KB1_blocks, KB2_blocks;
    dim_t K1padded, K2padded;

    dim_t Kproj, Kprojpadded;
    dim_t kproj_block, KBproj_blocks, kproj_tail;

    dim_t Nproj, Nproj_blocks, nproj_tail;
    dim_t LDAproj, LDBproj, LDCproj[4];
    int dhc_block_peephole, dhc_tail_peephole, dhc_blocks_peephole;

    dim_t nthr;
#if DNNL_X64
    x64::cpu_isa_t brgemm_isa;
#endif
    bool unfused_post_gemm;
};
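
// Illustrative sketch only (`pos` and `rnn` are assumed to come from the grid
// loop): the *_ld(cell_position) helpers above select between the user
// tensor's leading dimension and the workspace one, depending on which copies
// are skipped at that cell position.
//
//     const dim_t ld_src_layer = rnn.src_layer_ld(pos);
//     const dim_t ld_dst_layer
//             = rnn.dst_layer_ld(pos, /* after_proj = */ true);
//     const dim_t ld_dst_iter = rnn.dst_iter_ld(pos);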

bool is_ldigo(const memory_desc_wrapper &md);
bool is_ldgoi(const memory_desc_wrapper &md);
bool is_ldio(const memory_desc_wrapper &md);
bool is_ldoi(const memory_desc_wrapper &md);
bool is_ldigo_blocked(const memory_desc_wrapper &md);
bool is_ldgoi_blocked(const memory_desc_wrapper &md);
bool is_ldio_blocked(const memory_desc_wrapper &md);
bool is_ldoi_blocked(const memory_desc_wrapper &md);

int get_good_ld(int dim, int sizeof_dt);

template <typename T>
bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &src_iter_c_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &dst_layer_d,
        const memory_desc_wrapper &dst_iter_d,
        const memory_desc_wrapper &dst_iter_c_d) {
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_kind == dnnl_lbr_gru;
    rnn.is_lstm_peephole = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_peephole_desc).is_zero();
    rnn.is_lstm_projection = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_projection_desc).is_zero();

    switch (rd.direction) {
        case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break;
        case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break;
        case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break;
        case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break;
        default: break;
    }

    if (utils::everyone_is(data_type::f32, src_layer_d.data_type(),
                dst_layer_d.data_type(), weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (utils::everyone_is(data_type::bf16, src_layer_d.data_type(),
                     dst_layer_d.data_type(), weights_layer_d.data_type())) {
        if (!platform::has_data_type_support(data_type::bf16)) return false;
        rnn.dt_conf = all_bf16;
    } else if (dst_layer_d.data_type() == data_type::u8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else if (dst_layer_d.data_type() == data_type::s8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8s8;
        else
            rnn.dt_conf = f32s8f32s8;

    } else if (dst_layer_d.data_type() == data_type::f32) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8f32;
        else if (IMPLICATION(src_iter_d.md_,
                         src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8f32;
        else if (IMPLICATION(src_layer_d.md_,
                         src_layer_d.data_type() == data_type::s8))
            rnn.dt_conf = f32s8f32f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }

    // Set problem members defining problem sizes
    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1;
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dhc = weights_layer_d.dims()[4];
    rnn.dlc = rnn.is_lstm_projection ? weights_projection_d.dims()[3] : rnn.dhc;
    // All supported cells have dic == dlc
    rnn.dic = rnn.dlc;

    // set members with user memories leading dimensions
    // Assumption: weights datatype size is the same as state datatype size
    assert(types::data_type_size(weights_layer_d.data_type())
            == types::data_type_size(src_layer_d.data_type()));

    // set workspace leading dimensions (and non leading-dimensions)

    // the ws and scratch proj_ht need to match as we use them interchangeably
    assert(IMPLICATION(rnn.is_lstm_projection,
            sizeof(typename T::ht_t) == sizeof(typename T::dst_iter_t)));
    rnn.proj_ht_nld = rnn.mb;
    rnn.proj_ht_ld = get_good_ld(rnn.dhc, sizeof(typename T::ht_t));

    rnn.ws_gates_nld = rnn.mb;
    rnn.ws_gates_ld
            = get_good_ld(rnn.dhc * rnn.n_gates, sizeof(typename T::gates_t));
    rnn.ws_ht_nld = rnn.proj_ht_nld;
    rnn.ws_ht_ld = rnn.proj_ht_ld;

    rnn.ws_states_layer_nld = rnn.mb;
    static_assert(std::is_same<typename T::src_layer_t,
                          typename T::src_iter_t>::value,
            "src_layer_t and src_iter_t must be the same");
    rnn.ws_states_layer_ld
            = get_good_ld(nstl::max(rnn.sic, nstl::max(rnn.slc, rnn.dlc)),
                    sizeof(typename T::src_layer_t));
    // there is no need for a separate ws_states_iter for now as all
    // supported cells have dst_iter == dst_layer
    rnn.ws_states_iter_nld = rnn.ws_states_layer_nld;
    rnn.ws_states_iter_ld = rnn.ws_states_layer_ld;

    // we do not need a good ld for iter_c as it is not involved in GEMM
    rnn.ws_states_iter_c_nld = rnn.mb;
    rnn.ws_states_iter_c_ld = rnn.dhc;

    // TODO: be more restrictive on the leading dimensions
    rnn.ws_diff_states_layer_nld = rnn.mb;
    rnn.ws_diff_states_layer_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_nld = rnn.mb;
    rnn.ws_diff_states_iter_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_c_nld = rnn.mb;
    rnn.ws_diff_states_iter_c_ld = rnn.dhc;

    // set scratch (not)leading dimensions
    // scratch gates is used to store intermediate gates before postgemm operation
    // temporary: we also use it in lstmp as temporary scratchpad
    // between projection and downconversion, hence the max with dlc
    rnn.scratch_gates_nld = rnn.mb;
    rnn.scratch_gates_ld
            = get_good_ld(nstl::max(rnn.dlc, rnn.n_gates * rnn.dhc),
                    sizeof(typename T::scratch_t));
    rnn.scratch_ht_nld = rnn.proj_ht_nld;
    rnn.scratch_ht_ld = rnn.proj_ht_ld;

    rnn.scratch_diff_ht_nld = rnn.mb;
    rnn.scratch_diff_ht_ld
            = get_good_ld(rnn.dlc, sizeof(typename T::gemm_acc_t));

    // Assumption: {src,dst}_layer has tnc layout, {src,dst}_iter has ldnc
    rnn.src_layer_ld_ = src_layer_d.blocking_desc().strides[1];
    rnn.dst_layer_ld_ = dst_layer_d.blocking_desc().strides[1];
    rnn.src_iter_ld_ = types::is_zero_md(src_iter_d.md_)
            ? 0
            : src_iter_d.blocking_desc().strides[2];
    rnn.dst_iter_ld_ = types::is_zero_md(dst_iter_d.md_)
            ? 0
            : dst_iter_d.blocking_desc().strides[2];
    rnn.src_iter_c_ld_ = types::is_zero_md(src_iter_c_d.md_)
            ? 0
            : src_iter_c_d.blocking_desc().strides[2];
    rnn.dst_iter_c_ld_ = types::is_zero_md(dst_iter_c_d.md_)
            ? 0
            : dst_iter_c_d.blocking_desc().strides[2];

    /* Set the correct number of weights parts */
    bool is_orig_gru = rd.cell_kind == alg_kind::vanilla_gru;
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;

    rnn.n_parts_weights_projection = 1;
    rnn.parts_weights_projection[0] = 1;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    /* Decide which gemm implementation to use: packed/non-packed, jit/cblas,
     * and whether to merge gemm across iterations */
    bool is_f32 = rnn.dt_conf == all_f32, is_bf16 = rnn.dt_conf == all_bf16;
    bool is_gru = utils::one_of(
            rd.cell_kind, alg_kind::vanilla_gru, alg_kind::lbr_gru);
    bool is_inference = !rnn.is_training;

    // To be able to merge the GEMM on the layer input when not
    // copying, we need to have a trivial stride for the T dimension
    auto src_layer_is_trivial_stride = src_layer_d.blocking_desc().strides[0]
            == (rnn.src_layer_ld_ * rnn.mb);
    auto dst_layer_is_trivial_stride = dst_layer_d.blocking_desc().strides[0]
            == (rnn.dst_layer_ld_ * rnn.mb);

    rnn.merge_gemm_layer = (!rnn.is_brgemm)
            ? ((rnn.is_fwd && src_layer_is_trivial_stride)
                      || ((rd.prop_kind == prop_kind::backward)
                              && dst_layer_is_trivial_stride))
                    && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
                            || rnn.is_int8())
            : false;
    rnn.merge_gemm_iter = (!rnn.is_brgemm)
            ? dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru)
            : false;
    rnn.force_nocopy = false;
#if DNNL_X64
    rnn.force_nocopy = !x64::mayiuse(x64::avx512_mic) && x64::mayiuse(x64::avx)
            && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
                    || (rnn.is_training && rnn.dhc < 500));
#endif

    /* Decide to copy bias */
    rnn.copy_bias = rnn.is_int8();

    rnn.use_layer_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_layer_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8() || is_bf16)
            : false;
    rnn.use_iter_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_iter_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.mb >= 16)
                            || rnn.is_int8() || is_bf16)
            : false;
    rnn.use_projection_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_projection_d.format_kind(),
                      format_kind::any, format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8() || is_bf16)
            : false;

    /* Set packed gemm sizes */
    /* TODO: investigate the benefit of mixing packed and non-packed weights parts */
    auto set_pack_sizes
            = [&](bool merge, bool &do_pack, size_t &weights_pack_size,
                      int &n_parts, int *parts, size_t *parts_pack_size,
                      size_t &comp_offset, int ic, int oc, int weights_oc,
                      dim_t data_ld) -> bool {
        bool pack = true;
        weights_pack_size = 0;
        for (int p = 0; p < n_parts; p++) {
            dim_t m_p = rnn.is_fwd ? (parts[p] * oc) : ic;
            dim_t k_p = rnn.is_fwd ? ic : (parts[p] * oc);
            dim_t n_p = merge ? rnn.mb * rnn.n_iter : rnn.mb;
            bool pack_part = true;

            dnnl_status_t st = dnnl_success;
            switch (rnn.dt_conf) {
                case all_f32:
                    st = sgemm_pack_get_size("A", "N", "N", &m_p, &n_p, &k_p,
                            &m_p, &data_ld, &parts_pack_size[p], &pack_part);
                    break;
                case s8s8s8f32:
                case f32s8f32f32:
                case s8s8s8s8:
                case f32s8f32s8:
                    st = gemm_s8u8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case u8u8u8f32:
                case f32u8f32f32:
                case u8u8u8u8:
                case f32u8f32u8:
                    st = gemm_s8u8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case all_bf16:
                    st = gemm_bf16bf16f32_pack_get_size("A", "N", "N", &m_p,
                            &n_p, &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                default: assert(!"Unsupported configuration");
            }
            if (st != dnnl_success) return false;

            pack = pack && pack_part;
            weights_pack_size += rnn.n_layer * rnn.n_dir * parts_pack_size[p];
        }

        // NOTE: pack is updated only for f32. We force pack for int8
        do_pack = (rnn.dt_conf == all_f32) ? pack : true;
        comp_offset = weights_pack_size;
        const bool need_compensation = rnn.is_int8();
        weights_pack_size += (need_compensation ? rnn.n_layer * rnn.n_dir : 0)
                * weights_oc * sizeof(float);

        return true;
    };
    // TODO: the activation leading dimension can vary for first layer/iteration
    if (rnn.use_layer_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_layer,
                rnn.use_layer_packed_gemm, rnn.weights_layer_pack_size,
                rnn.n_parts_weights_layer, rnn.parts_weights_layer,
                rnn.part_weights_layer_pack_size, rnn.weights_layer_comp_offset,
                rnn.slc, rnn.dhc, rnn.n_gates * rnn.dhc,
                rnn.ws_states_layer_ld);
        if (!ok) return false;
    }

    if (rnn.use_iter_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_iter, rnn.use_iter_packed_gemm,
                rnn.weights_iter_pack_size, rnn.n_parts_weights_iter,
                rnn.parts_weights_iter, rnn.part_weights_iter_pack_size,
                rnn.weights_iter_comp_offset, rnn.sic, rnn.dhc,
                rnn.n_gates * rnn.dhc, rnn.ws_states_iter_ld);
        if (!ok) return false;
    }

    if (rnn.use_projection_packed_gemm) {
        bool ok = set_pack_sizes(false, rnn.use_projection_packed_gemm,
                rnn.weights_projection_pack_size,
                rnn.n_parts_weights_projection, rnn.parts_weights_projection,
                rnn.part_weights_projection_pack_size,
                rnn.weights_projection_comp_offset, rnn.dhc, rnn.dic, rnn.dic,
                rnn.scratch_ht_ld);
        if (!ok) return false;
    }

    return true;
}
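
// A minimal usage sketch (hedged: `cell_traits_t` stands for the hypothetical
// bundle of data-type typedefs T is expected to provide, e.g. src_layer_t,
// gates_t, scratch_t, ht_t, gemm_acc_t; the *_d variables are
// memory_desc_wrapper objects owned by the caller):
//
//     rnn_utils::rnn_conf_t rnn;
//     const bool args_ok = rnn_utils::init_conf<cell_traits_t>(rnn, rd,
//             src_layer_d, src_iter_d, src_iter_c_d, weights_layer_d,
//             weights_iter_d, weights_projection_d, dst_layer_d, dst_iter_d,
//             dst_iter_c_d);
//     if (!args_ok) return status::unimplemented;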

template <typename T>
void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d,
        const memory_desc_wrapper &diff_weights_projection_d) {

    // Set leading dimensions for input weights arrays depending on input format
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        ld = 0;
        nld = 0;
        if (md.is_blocking_desc()) {
            if (is_ldigo(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else if (is_ldgoi(md)) {
                ld = (int)md.blocking_desc().strides[4];
                nld = md.dims()[3] * md.dims()[4];
            } else if (is_ldoi(md)) {
                ld = (int)md.blocking_desc().strides[3];
                nld = md.dims()[3];
            } else if (is_ldio(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else
                assert(!"unsupported weights format");
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    set_dims(weights_projection_d, rnn.weights_projection_ld,
            rnn.weights_projection_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
        set_dims(diff_weights_projection_d, rnn.diff_weights_projection_ld,
                rnn.diff_weights_projection_nld);
    }

    assert(weights_layer_d.data_type() == weights_iter_d.data_type());
    assert(IMPLICATION(diff_weights_layer_d.ndims() != 0,
            (diff_weights_layer_d.data_type()
                    == diff_weights_iter_d.data_type())));

    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */

    assert(sizeof(typename T::src_layer_t) == sizeof(typename T::dst_layer_t));
    assert(sizeof(typename T::src_iter_t) == sizeof(typename T::dst_iter_t));

    rnn.use_workspace = rnn.is_training;
    // TODO: for inference, we can make ws_states_* smaller, but
    // it would depend on the grid execution
    rnn.ws_states_layer_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_layer_ld
            * sizeof(typename T::src_layer_t);
    rnn.ws_states_iter_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_ld
            * sizeof(typename T::src_iter_t);
    bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm;
    rnn.ws_states_iter_c_size = is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_states_iter_c_ld * sizeof(float)
            : 0;

    rnn.ws_diff_states_layer_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_layer_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_c_size = rnn.is_training && is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_c_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;

    rnn.ws_gates_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_gates_nld
                    * rnn.ws_gates_ld * sizeof(typename T::gates_t)
            : (size_t)0;
    rnn.ws_ht_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_ht_nld
                    * rnn.ws_ht_ld * sizeof(typename T::dst_iter_t)
            : (size_t)0;
    rnn.n_iter_scratch_gates
            = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1;
    rnn.scratch_gates_size = rnn.n_iter_scratch_gates * rnn.scratch_gates_nld
            * rnn.scratch_gates_ld * sizeof(typename T::scratch_t);
    rnn.scratch_ht_size
            = rnn.scratch_ht_nld * rnn.scratch_ht_ld * sizeof(typename T::ht_t);
    rnn.scratch_diff_ht_size = rnn.is_training ? rnn.scratch_diff_ht_nld
                    * rnn.scratch_diff_ht_ld * sizeof(typename T::gemm_acc_t)
                                               : (size_t)0;

    /* set other sizes */
    /// scratchpad buffer for each cell to hold intermediate data in gru/lbr_gru
    rnn.scratch_cell_size = rnn.is_lbr
            ? (size_t)rnn.scratch_gates_nld * rnn.scratch_gates_ld
                    * sizeof(typename T::gemm_acc_t)
            : (rd.cell_kind == alg_kind::vanilla_gru
                            ? (size_t)rnn.ws_states_layer_nld
                                    * rnn.ws_states_layer_ld
                                    * sizeof(typename T::gemm_acc_t)
                            : 0);
    /// workspace needed for lbr GRU
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc
            * sizeof(typename T::gemm_acc_t);
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
    /// bias ws needed to add compensation in int8
    rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc
            * sizeof(float);
}

void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_ht_offset, size_t &ws_state_layer_offset,
        size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset,
        size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset,
        size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset,
        size_t &scratch_cell_offset, size_t &scratchpad_size,
        size_t &workspace_size);

void get_scratchpad_and_workspace_sizes(
        const rnn_conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size);
status_t set_expected_desc(rnn_conf_t &rnn, memory_desc_t &weights_md,
        weights_type_t weights_type);
status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);
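
// Illustrative sketch only (`rnn` is assumed to be a fully initialized
// rnn_conf_t): the helpers declared above report how many bytes of scratchpad
// and workspace the configuration needs.
//
//     size_t scratchpad_sz = 0, workspace_sz = 0;
//     get_scratchpad_and_workspace_sizes(rnn, scratchpad_sz, workspace_sz);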

template <typename T>
struct ws_gates_aoc {
    ws_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.ws_gates_nld, rnn.ws_gates_ld), DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using ws_gates_aoc_t = ws_gates_aoc<float>;
using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
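
// Illustrative sketch only (pointer and index names are hypothetical): the
// array-offset-calculator ("aoc") wrappers here and below turn a flat buffer
// plus the (nld, ld) pair stored in rnn_conf_t into multi-dimensional
// indexing, e.g.
//
//     ws_gates_aoc_t ws_gates(rnn, ws_gates_ptr);
//     float &g = ws_gates(mb_idx, gate_idx, dhc_idx);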

template <typename T>
struct ws_ht_aoc {
    ws_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.ws_ht_nld, rnn.ws_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};

template <typename T>
struct scratch_gates_aoc {
    scratch_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.scratch_gates_nld, rnn.scratch_gates_ld)
        , DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using scratch_gates_aoc_t = scratch_gates_aoc<float>;
using scratch_gates_aoc_s32_t = scratch_gates_aoc<int32_t>;

template <typename T>
struct scratch_ht_aoc {
    scratch_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.scratch_ht_nld, rnn.scratch_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};
using scratch_ht_aoc_t = scratch_ht_aoc<float>;
using scratch_ht_aoc_s32_t = scratch_ht_aoc<int32_t>;

template <typename T>
struct weights_peephole_aoc_t {
    weights_peephole_aoc_t(const rnn_conf_t &rnn, T *data)
        : weights_peephole_(data, 3, rnn.dhc) {}
    T &operator()(int g, int dhc) const { return weights_peephole_(g, dhc); }

private:
    const utils::array_offset_calculator<T, 2> weights_peephole_;
};

struct bias_aoc_t {
    bias_aoc_t(const rnn_conf_t &rnn, const float *data)
        : bias_(data, rnn.n_bias, rnn.dhc) {}
    const float &operator()(int bias_n, int dhc) const {
        return bias_(bias_n, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<const float, 2> bias_;
};

template <typename T>
struct ws_states_layer_aoc {
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_layer_nld, leading_dim) {}
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_states_iter_aoc {
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_iter_nld, leading_dim) {}
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_iter_nld, rnn.ws_states_iter_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_states_iter_c_aoc {
    ws_states_iter_c_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_iter_c_nld, leading_dim) {}
    ws_states_iter_c_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_iter_c_nld, rnn.ws_states_iter_c_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_diff_states_layer_aoc {
    ws_diff_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_layer_(data, rnn.ws_diff_states_layer_nld,
                rnn.ws_diff_states_layer_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_layer_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_layer_;
};

template <typename T>
struct ws_diff_states_iter_aoc {
    ws_diff_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_(data, rnn.ws_diff_states_iter_nld,
                rnn.ws_diff_states_iter_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_;
};

template <typename T>
struct ws_diff_states_iter_c_aoc {
    ws_diff_states_iter_c_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_c_(data, rnn.ws_diff_states_iter_c_nld,
                rnn.ws_diff_states_iter_c_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_c_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_c_;
};

struct ws_diff_w_iter_aoc_t {
    ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
        : diff_weights_iter_(
                data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
        , DHC_(rnn.dhc) {}
    float &operator()(int sic, int gate, int dhc) const {
        return diff_weights_iter_(sic, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<float, 2>
            diff_weights_iter_;
    const int DHC_;
};

} // namespace rnn_utils
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif