/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_RNN_UTILS_HPP
#define CPU_RNN_RNN_UTILS_HPP

#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/gemm_pack.hpp"

#if DNNL_X64
#include "cpu/x64/cpu_isa_traits.hpp"
#endif

#define rnn_postgemm_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, gates_t *ws_gates_, \
            scratch_t *scratch_gates_, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            gemm_acc_t *diff_dst_layer_, gemm_acc_t *diff_dst_iter_, \
            gemm_acc_t *diff_dst_iter_c_, const float *weights_peephole_, \
            float *bias_, gates_t *ws_grid_, scratch_t *scratch_cell_, \
            dst_iter_t *dst_iter_, float *weights_scales_, int block_step) \
            const

#if DNNL_X64
#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            weights_t **w_layer_, weights_t **w_iter_, \
            weights_t **w_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            dst_iter_t *dst_iter_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, float *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            float *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const
#else
#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            float *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            weights_t **w_layer_, weights_t **w_iter_, \
            weights_t **w_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, dst_iter_t *dst_iter_, \
            gemm_acc_t *amx_scratchpad) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, float **bias_, \
            const src_layer_t *src_layer_, const src_iter_t *src_iter_, \
            const float *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, float *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            float *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, gemm_acc_t *diff_weights_layer_, \
            gemm_acc_t *diff_weights_iter_, float *diff_weights_projection_, \
            float *diff_weights_peephole_, float *diff_bias_, \
            gemm_acc_t *amx_scratchpad) const
#endif

#define rnn_gemm_sig(f) \
    dnnl_status_t f(const char transA, const char transB, dim_t m, dim_t n, \
            dim_t k, const float alpha, const weights_t *a_, const dim_t ldA, \
            const gemm_data_t *b_, const dim_t ldB, const float beta, \
            gemm_acc_t *c_, const dim_t ldC) const

#define rnn_bias_prepare_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
            float *scratch_bias_) const

#define rnn_bias_finalize_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
            const float *w_iter_comp, const float *w_layer_comp) const

#define rnn_weights_assign_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, \
            int n_parts, const int *gates_per_part, weights_t **weights_, \
            const weights_t *w_) const

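// A minimal sketch (hypothetical, not part of this header) of how the
// signature macros above are meant to be used by a cell implementation: the
// macro expands to the full parameter list, so a class body can declare
//     rnn_cell_execution_sig(cell_execution_ref);
// and the matching out-of-class definition can be written as
//     template <...>
//     rnn_cell_execution_sig((_ref_rnn_common_t<...>::cell_execution_ref)) {
//         /* cell body */
//     }
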
namespace dnnl {
namespace impl {
namespace cpu {

namespace rnn_utils {

enum execution_direction_t {
    l2r,
    r2l,
    bi_concat,
    bi_sum,
};

enum cell_position_t {
    middle_cell = 0x0,
    first_layer = 0x1,
    first_iter = 0x2,
    last_layer = 0x4,
    last_iter = 0x8,
    c_state_first_iter = 0x10,
    c_state_last_iter = 0x20
};

enum class weights_type_t {
    layer,
    iter,
    projection,
    peephole,
};

inline cell_position_t &operator|=(cell_position_t &lhs, cell_position_t rhs) {
    lhs = static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
    return lhs;
}

inline cell_position_t operator|(cell_position_t lhs, cell_position_t rhs) {
    return static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
}
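
// Illustrative (hypothetical) use of the bit-flags above: a cell sitting in
// the first layer and the last iteration of the grid can be described as
//     cell_position_t pos = middle_cell;
//     pos |= first_layer | last_iter;
// after which (pos & first_layer) and (pos & last_iter) are non-zero and all
// other position bits stay clear.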

enum data_type_conf_t {
    all_f32,
    all_bf16,
    u8u8u8f32,
    f32u8f32f32,
    u8u8u8u8,
    f32u8f32u8,
    s8s8s8f32,
    f32s8f32f32,
    s8s8s8s8,
    f32s8f32s8
};

struct diff_src_brgemm_conf_t {
    dim_t M = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;

    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t N_iter = 0, N_layer = 0;
    dim_t N_layer_blocks = 0, n_layer_tail = 0;
    dim_t N_iter_blocks = 0, n_iter_tail = 0;
    dim_t LDA = 0, LDB = 0, LDC = 0;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_any;
#endif
};

struct diff_wei_brgemm_conf_t {
    dim_t M = 0, M_layer = 0, M_iter = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;
    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t LDA_layer = 0, LDA_iter = 0, LDB = 0, LDC_iter = 0, LDC_layer = 0;

    bool global_transpose = false;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_any;
#endif
};

struct rnn_conf_t {
    execution_direction_t exec_dir;
    data_type_conf_t dt_conf;
    int n_layer = 0, n_iter = 0, n_dir = 0, n_gates = 0, n_states = 0;
    int mb = 0;
    int slc = 0, sic = 0, dhc = 0, dic = 0, dlc = 0;
    //int gates_ld, gates_nld, gates_ws_ld;

    int n_parts_weights_layer = 0;
    int parts_weights_layer[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_iter = 0;
    int parts_weights_iter[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_projection = 0;
    int parts_weights_projection[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_projection_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_bias = 0, n_parts_bias = 0, parts_bias[DNNL_RNN_MAX_N_PARTS];

    /* Size of packed data in bytes */
    size_t weights_layer_comp_offset = 0, weights_layer_pack_size = 0;
    size_t weights_iter_comp_offset = 0, weights_iter_pack_size = 0;
    size_t weights_projection_comp_offset = 0, weights_projection_pack_size = 0;

    bool copy_bias = 0;
    int weights_layer_ld = 0, weights_layer_nld = 0;
    int diff_weights_layer_ld = 0, diff_weights_layer_nld = 0;
    int weights_iter_ld = 0, weights_iter_nld = 0;
    int diff_weights_iter_ld = 0, diff_weights_iter_nld = 0;
    int weights_projection_ld = 0, weights_projection_nld = 0;
    int diff_weights_projection_ld = 0, diff_weights_projection_nld = 0;

    int proj_ht_ld = 0, proj_ht_nld = 0;

    int ws_gates_ld = 0, ws_gates_nld = 0;
    int ws_ht_ld = 0, ws_ht_nld = 0;
    int ws_states_layer_ld = 0, ws_states_layer_nld = 0;
    int ws_states_iter_ld = 0, ws_states_iter_nld = 0;
    int ws_states_iter_c_ld = 0, ws_states_iter_c_nld = 0;
    int ws_diff_states_layer_ld = 0, ws_diff_states_layer_nld = 0;
    int ws_diff_states_iter_ld = 0, ws_diff_states_iter_nld = 0;
    int ws_diff_states_iter_c_ld = 0, ws_diff_states_iter_c_nld = 0;

    int scratch_gates_ld = 0, scratch_gates_nld = 0;
    int scratch_ht_ld = 0, scratch_ht_nld = 0;
    int scratch_diff_ht_ld = 0, scratch_diff_ht_nld = 0;

    int src_layer_ld_ = 0, src_layer_nld_ = 0;
    int src_iter_ld_ = 0, src_iter_nld_ = 0;
    int src_iter_c_ld_ = 0, src_iter_c_nld_ = 0;
    int dst_layer_ld_ = 0, dst_layer_nld_ = 0;
    int dst_iter_ld_ = 0, dst_iter_nld_ = 0;
    int dst_iter_c_ld_ = 0, dst_iter_c_nld_ = 0;

    int weights_iter_compensation_size = 0, weights_layer_compensation_size = 0;
    bool is_fwd = 0, is_training = 0, is_lbr = 0, is_lstm_peephole = 0,
         is_lstm_projection = 0;
    bool use_workspace = 0;

    // Size of workspace for each tensor in bytes
    // Notes:
    // 1. For non-LSTMP ws_states_iter_size == ws_states_layer_size. The
    //    corresponding pointers should point to the same places.
    size_t ws_gates_size = 0;
    size_t ws_ht_size = 0;
    size_t ws_states_layer_size = 0;
    size_t ws_states_iter_size = 0;
    size_t ws_states_iter_c_size = 0;
    size_t ws_diff_states_layer_size = 0;
    size_t ws_diff_states_iter_size = 0;
    size_t ws_diff_states_iter_c_size = 0;
    size_t scratch_gates_size = 0;

    size_t scratch_gates_blocked_size = 0;
    size_t scratch_gates_blocked_nested_reorder_size = 0;
    size_t scratch_src_layer_size = 0;
    size_t scratch_src_layer_nested_reorder_size = 0;
    size_t scratch_src_iter_size = 0;
    size_t scratch_src_iter_nested_reorder_size = 0;

    size_t scratch_ht_size = 0;
    size_t scratch_diff_ht_size = 0;
    size_t scratch_cell_size = 0;
    size_t ws_grid_comp_size = 0;
    size_t ws_per_cell = 0;
    size_t ws_bias_size = 0;

    bool merge_gemm_iter = false, merge_gemm_layer = false,
         force_nocopy = false, use_layer_packed_gemm = false,
         use_iter_packed_gemm = false, use_projection_packed_gemm = false;
    int n_iter_scratch_gates = 0;

    inline bool is_int8() const {
        return is_signed_int8() || is_unsigned_int8();
    }
    inline bool is_signed_int8() const {
        return utils::one_of(
                dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8, f32s8f32s8);
    }
    inline bool is_unsigned_int8() const {
        return utils::one_of(
                dt_conf, u8u8u8f32, f32u8f32f32, u8u8u8u8, f32u8f32u8);
    }
    inline bool is_int8_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_bf16_amx_int8 && is_int8();
#else
        return false;
#endif
    }
    inline bool is_bf16() const { return dt_conf == all_bf16; }
    inline bool is_bf16_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_bf16_amx_bf16 && is_bf16();
#else
        return false;
#endif
    }
    inline bool is_f32() const { return dt_conf == all_f32; }

    inline bool skip_src_layer_copy() const {
        // Note: this currently always returns true
        return (exec_dir == l2r)
                && utils::one_of(dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8,
                        f32s8f32s8, u8u8u8u8, u8u8u8f32, f32u8f32u8,
                        f32u8f32f32, all_f32, all_bf16);
    }
    inline bool skip_src_iter_copy() const {
        return (exec_dir == l2r) && (src_iter_ld_ > 0)
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }
    inline bool skip_dst_layer_copy() const {
        return (exec_dir == l2r)
                && utils::one_of(dt_conf, s8s8s8s8, f32s8f32s8, u8u8u8u8,
                        f32u8f32u8, all_f32, all_bf16);
    }
    inline bool skip_dst_iter_copy() const {
        return (exec_dir == l2r) && (dst_iter_ld_ > 0)
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }
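
    // Illustrative example (hypothetical configuration): a unidirectional
    // left-to-right f32 execution with non-empty src_iter and dst_iter
    // memories satisfies all four predicates above, so the border cells read
    // from and write to the user buffers directly instead of going through
    // the ws_states_* copies.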

    inline dim_t src_layer_ld(cell_position_t cell_position) const {
        return (cell_position & first_layer) && skip_src_layer_copy()
                ? src_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t src_iter_ld(cell_position_t cell_position) const {
        return (cell_position & first_iter) && skip_src_iter_copy()
                ? src_iter_ld_
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                                  && !(cell_position & first_iter)
                                ? dst_layer_ld_
                                : ws_states_iter_ld);
    }
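
    // For instance (hypothetical position): a cell tagged
    // (first_iter | last_layer) with skip_src_iter_copy() == true gets the
    // user-provided src_iter_ld_ from src_iter_ld() above; once
    // skip_src_iter_copy() is false it falls back to ws_states_iter_ld.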

    inline dim_t layer_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_layer) && skip_src_layer_copy())
                ? 0
                : ((cell_position & last_iter) && skip_dst_iter_copy()) ? 1 : 2;
    }

    inline dim_t iter_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_iter) && skip_src_iter_copy())
                ? 0
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                          && !(cell_position & first_iter))
                        ? 1
                        : 2;
    }

    inline dim_t src_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_first_iter) ? src_iter_c_ld_
                                                    : ws_states_iter_c_ld;
    }

    inline dim_t dst_layer_ld(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return scratch_ht_ld;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? dst_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t dst_brgemm_desc(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return 0;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? 1
                : (cell_position & last_iter) && skip_dst_iter_copy() ? 2 : 3;
    }

    inline dim_t dst_iter_ld(cell_position_t cell_position) const {
        return (cell_position & last_iter) && skip_dst_iter_copy()
                ? dst_iter_ld_
                : ws_states_iter_ld;
    }

    inline dim_t dst_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_last_iter) ? dst_iter_c_ld_
                                                   : ws_states_iter_c_ld;
    }

    // // when skipping copy, the output ld can be states_ws_ld,
    // // dst_iter_ld or dst_layer_ld depending on the cell position
    // inline dim_t dst_ld(cell_position_t cell_position) const {
    //     return (cell_position & last_layer) ? dst_layer_ld(cell_position)
    //                                         : dst_iter_ld(cell_position);
    // }
    inline dim_t dst_copy_ld(cell_position_t cell_position) const {
        return dst_iter_ld(cell_position);
    }

    inline bool need_gemm_layer(cell_position_t cell_position) const {
        // In case of merge_gemm_layer we might still need a layer gemm if we
        // store the states of the last iteration in the destination memory.
        // The exception to this rule is the first layer, in which case all
        // states are kept in the user's src_layer, hence making a full merged
        // gemm possible.
        return IMPLICATION(merge_gemm_layer,
                skip_dst_iter_copy() && (cell_position & last_iter)
                        && !(cell_position & first_layer));
    }
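
    // Example (hypothetical 2-layer forward run with merge_gemm_layer and
    // skip_dst_iter_copy() == true): layer 0 reads all of its inputs from the
    // user's src_layer, so its layer GEMM stays fully merged; layer 1,
    // however, finds layer 0's last-iteration output in the user's dst_iter
    // rather than in the workspace, so its last-iteration cell still needs a
    // separate layer GEMM, which is exactly the (last_iter && !first_layer)
    // case above.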
    bool is_brgemm;

    diff_src_brgemm_conf_t diff_src_brgemm;
    diff_wei_brgemm_conf_t diff_wei_brgemm;

    dim_t M, N, K1, K2;

    dim_t LDB1, LDB2;
    dim_t LDA1[3];
    dim_t LDA2[3];
    dim_t LDC;

    dim_t m_block, M_blocks;
    dim_t n_block, N_blocks, n_tail;

    dim_t k2_block, k1_block, k1_tail, k2_tail;
    dim_t KB1_blocks, KB2_blocks;
    dim_t K1padded, K2padded;

    dim_t Kproj, Kprojpadded;
    dim_t kproj_block, KBproj_blocks, kproj_tail;

    dim_t Nproj, Nproj_blocks, nproj_tail;
    dim_t LDAproj, LDBproj, LDCproj[4];
    int dhc_block_peephole, dhc_tail_peephole, dhc_blocks_peephole;

    dim_t nthr;
#if DNNL_X64
    x64::cpu_isa_t brgemm_isa;
#endif
    bool unfused_post_gemm;
};

bool is_ldigo(const memory_desc_wrapper &md);
bool is_ldgoi(const memory_desc_wrapper &md);
bool is_ldio(const memory_desc_wrapper &md);
bool is_ldoi(const memory_desc_wrapper &md);
bool is_ldigo_blocked(const memory_desc_wrapper &md);
bool is_ldgoi_blocked(const memory_desc_wrapper &md);
bool is_ldio_blocked(const memory_desc_wrapper &md);
bool is_ldoi_blocked(const memory_desc_wrapper &md);

int get_good_ld(int dim, int sizeof_dt);
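// A rough sketch (assumption; the actual definition lives in the .cpp file)
// of what a "good" leading dimension means here: pad the dimension up to a
// full cache line worth of elements, then nudge it off large power-of-two
// strides to limit cache aliasing, e.g.
//     int ld = utils::rnd_up(dim, 64 / sizeof_dt);
//     return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;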

template <typename T>
bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &src_iter_c_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &dst_layer_d,
        const memory_desc_wrapper &dst_iter_d,
        const memory_desc_wrapper &dst_iter_c_d) {
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = rd.cell_kind == dnnl_lbr_gru;
    rnn.is_lstm_peephole = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_peephole_desc).is_zero();
    rnn.is_lstm_projection = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_projection_desc).is_zero();

    switch (rd.direction) {
        case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break;
        case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break;
        case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break;
        case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break;
        default: break;
    }

    if (utils::everyone_is(data_type::f32, src_layer_d.data_type(),
                dst_layer_d.data_type(), weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (utils::everyone_is(data_type::bf16, src_layer_d.data_type(),
                     dst_layer_d.data_type(), weights_layer_d.data_type())) {
        if (!platform::has_data_type_support(data_type::bf16)) return false;
        rnn.dt_conf = all_bf16;
    } else if (dst_layer_d.data_type() == data_type::u8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else if (dst_layer_d.data_type() == data_type::s8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8s8;
        else
            rnn.dt_conf = f32s8f32s8;

    } else if (dst_layer_d.data_type() == data_type::f32) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8f32;
        else if (IMPLICATION(src_iter_d.md_,
                         src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8f32;
        else if (IMPLICATION(src_layer_d.md_,
                         src_layer_d.data_type() == data_type::s8))
            rnn.dt_conf = f32s8f32f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }

    // Set problem members defining problem sizes
    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1;
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dhc = weights_layer_d.dims()[4];
    rnn.dlc = rnn.is_lstm_projection ? weights_projection_d.dims()[3] : rnn.dhc;
    // All supported cells have dic == dlc
    rnn.dic = rnn.dlc;

    // set members with user memories leading dimensions
    // Assumption: weights datatype size is the same as state datatype size
    assert(types::data_type_size(weights_layer_d.data_type())
            == types::data_type_size(src_layer_d.data_type()));

    // set workspace leading dimensions (and non leading-dimensions)

    // the ws and scratch proj_ht need to match as we use them interchangeably
    assert(IMPLICATION(rnn.is_lstm_projection,
            sizeof(typename T::ht_t) == sizeof(typename T::dst_iter_t)));
    rnn.proj_ht_nld = rnn.mb;
    rnn.proj_ht_ld = get_good_ld(rnn.dhc, sizeof(typename T::ht_t));

    rnn.ws_gates_nld = rnn.mb;
    rnn.ws_gates_ld
            = get_good_ld(rnn.dhc * rnn.n_gates, sizeof(typename T::gates_t));
    rnn.ws_ht_nld = rnn.proj_ht_nld;
    rnn.ws_ht_ld = rnn.proj_ht_ld;

    rnn.ws_states_layer_nld = rnn.mb;
    static_assert(std::is_same<typename T::src_layer_t,
                          typename T::src_iter_t>::value,
            "src_layer_t and src_iter_t must be the same");
    rnn.ws_states_layer_ld
            = get_good_ld(nstl::max(rnn.sic, nstl::max(rnn.slc, rnn.dlc)),
                    sizeof(typename T::src_layer_t));
    // there is no need for a separate ws_states_iter for now as all
    // supported cells have dst_iter == dst_layer
    rnn.ws_states_iter_nld = rnn.ws_states_layer_nld;
    rnn.ws_states_iter_ld = rnn.ws_states_layer_ld;

    // we do not need a good ld for iter_c as it is not involved in GEMM
    rnn.ws_states_iter_c_nld = rnn.mb;
    rnn.ws_states_iter_c_ld = rnn.dhc;

    // TODO: be more restrictive on the leading dimensions
    rnn.ws_diff_states_layer_nld = rnn.mb;
    rnn.ws_diff_states_layer_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_nld = rnn.mb;
    rnn.ws_diff_states_iter_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_c_nld = rnn.mb;
    rnn.ws_diff_states_iter_c_ld = rnn.dhc;

    // set scratch (not)leading dimensions
    // scratch gates is used to store intermediate gates before postgemm operation
    // temporary: we also use it in lstmp as temporary scratchpad
    // between projection and downconversion, hence the max with dlc
    rnn.scratch_gates_nld = rnn.mb;
    rnn.scratch_gates_ld
            = get_good_ld(nstl::max(rnn.dlc, rnn.n_gates * rnn.dhc),
                    sizeof(typename T::scratch_t));
    rnn.scratch_ht_nld = rnn.proj_ht_nld;
    rnn.scratch_ht_ld = rnn.proj_ht_ld;

    rnn.scratch_diff_ht_nld = rnn.mb;
    rnn.scratch_diff_ht_ld
            = get_good_ld(rnn.dlc, sizeof(typename T::gemm_acc_t));

    // Assumption: {src,dst}_layer has tnc layout, {src,dst}_iter has ldnc,
    rnn.src_layer_ld_ = src_layer_d.blocking_desc().strides[1];
    rnn.dst_layer_ld_ = dst_layer_d.blocking_desc().strides[1];
    rnn.src_iter_ld_ = types::is_zero_md(src_iter_d.md_)
            ? 0
            : src_iter_d.blocking_desc().strides[2];
    rnn.dst_iter_ld_ = types::is_zero_md(dst_iter_d.md_)
            ? 0
            : dst_iter_d.blocking_desc().strides[2];
    rnn.src_iter_c_ld_ = types::is_zero_md(src_iter_c_d.md_)
            ? 0
            : src_iter_c_d.blocking_desc().strides[2];
    rnn.dst_iter_c_ld_ = types::is_zero_md(dst_iter_c_d.md_)
            ? 0
            : dst_iter_c_d.blocking_desc().strides[2];

    /* Set the correct number of weights parts */
    bool is_orig_gru = rd.cell_kind == alg_kind::vanilla_gru;
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;

    rnn.n_parts_weights_projection = 1;
    rnn.parts_weights_projection[0] = 1;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas
     * and whether to merge gemm across iterations */
    bool is_f32 = rnn.dt_conf == all_f32, is_bf16 = rnn.dt_conf == all_bf16;
    bool is_gru = utils::one_of(
            rd.cell_kind, alg_kind::vanilla_gru, alg_kind::lbr_gru);
    bool is_inference = !rnn.is_training;

    // To be able to merge the GEMM on the layer input when not
    // copying, we need to have a trivial stride for the T dimension
    auto src_layer_is_trivial_stride = src_layer_d.blocking_desc().strides[0]
            == (rnn.src_layer_ld_ * rnn.mb);
    auto dst_layer_is_trivial_stride = dst_layer_d.blocking_desc().strides[0]
            == (rnn.dst_layer_ld_ * rnn.mb);

    rnn.merge_gemm_layer = (!rnn.is_brgemm)
            ? ((rnn.is_fwd && src_layer_is_trivial_stride)
                      || ((rd.prop_kind == prop_kind::backward)
                              && dst_layer_is_trivial_stride))
                    && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
                            || rnn.is_int8())
            : false;
    rnn.merge_gemm_iter = (!rnn.is_brgemm)
            ? dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru)
            : false;
    rnn.force_nocopy = false;
#if DNNL_X64
    rnn.force_nocopy = !x64::mayiuse(x64::avx512_mic) && x64::mayiuse(x64::avx)
            && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
                    || (rnn.is_training && rnn.dhc < 500));
#endif

    /* Decide to copy bias */
    rnn.copy_bias = rnn.is_int8();

    rnn.use_layer_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_layer_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8() || is_bf16)
            : false;
    rnn.use_iter_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_iter_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.mb >= 16)
                            || rnn.is_int8() || is_bf16)
            : false;
    rnn.use_projection_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_projection_d.format_kind(),
                      format_kind::any, format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8() || is_bf16)
            : false;

    /* Set packed gemm sizes */
    /* TODO: investigate the benefit of mixing packed and non-packed weights parts */
    auto set_pack_sizes
            = [&](bool merge, bool &do_pack, size_t &weights_pack_size,
                      int &n_parts, int *parts, size_t *parts_pack_size,
                      size_t &comp_offset, int ic, int oc, int weights_oc,
                      dim_t data_ld) -> bool {
        bool pack = true;
        weights_pack_size = 0;
        for (int p = 0; p < n_parts; p++) {
            dim_t m_p = rnn.is_fwd ? (parts[p] * oc) : ic;
            dim_t k_p = rnn.is_fwd ? ic : (parts[p] * oc);
            dim_t n_p = merge ? rnn.mb * rnn.n_iter : rnn.mb;
            bool pack_part = true;

            dnnl_status_t st = dnnl_success;
            switch (rnn.dt_conf) {
                case all_f32:
                    st = sgemm_pack_get_size("A", "N", "N", &m_p, &n_p, &k_p,
                            &m_p, &data_ld, &parts_pack_size[p], &pack_part);
                    break;
                case s8s8s8f32:
                case f32s8f32f32:
                case s8s8s8s8:
                case f32s8f32s8:
                    st = gemm_s8u8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case u8u8u8f32:
                case f32u8f32f32:
                case u8u8u8u8:
                case f32u8f32u8:
                    st = gemm_s8u8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case all_bf16:
                    st = gemm_bf16bf16f32_pack_get_size("A", "N", "N", &m_p,
                            &n_p, &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                default: assert(!"Unsupported configuration");
            }
            if (st != dnnl_success) return false;

            pack = pack && pack_part;
            weights_pack_size += rnn.n_layer * rnn.n_dir * parts_pack_size[p];
        }

        // NOTE: pack is updated only for f32. We force pack for int8
        do_pack = (rnn.dt_conf == all_f32) ? pack : true;
        comp_offset = weights_pack_size;
        const bool need_compensation = rnn.is_int8();
        weights_pack_size += (need_compensation ? rnn.n_layer * rnn.n_dir : 0)
                * weights_oc * sizeof(float);

        return true;
    };
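
    // For instance (hypothetical sizes): packing the layer weights of a
    // forward cell with slc = 128, dhc = 256 and a single part spanning
    // n_gates = 4 gates queries the packed size of an "A" matrix with
    // m = 4 * 256, k = 128 and n = mb (or mb * n_iter when the layer GEMM is
    // merged); the resulting per-part size is then accounted once per layer
    // and direction.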
    // TODO: the activation leading dimension can vary for first layer/iteration
    if (rnn.use_layer_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_layer,
                rnn.use_layer_packed_gemm, rnn.weights_layer_pack_size,
                rnn.n_parts_weights_layer, rnn.parts_weights_layer,
                rnn.part_weights_layer_pack_size, rnn.weights_layer_comp_offset,
                rnn.slc, rnn.dhc, rnn.n_gates * rnn.dhc,
                rnn.ws_states_layer_ld);
        if (!ok) return false;
    }

    if (rnn.use_iter_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_iter, rnn.use_iter_packed_gemm,
                rnn.weights_iter_pack_size, rnn.n_parts_weights_iter,
                rnn.parts_weights_iter, rnn.part_weights_iter_pack_size,
                rnn.weights_iter_comp_offset, rnn.sic, rnn.dhc,
                rnn.n_gates * rnn.dhc, rnn.ws_states_iter_ld);
        if (!ok) return false;
    }

    if (rnn.use_projection_packed_gemm) {
        bool ok = set_pack_sizes(false, rnn.use_projection_packed_gemm,
                rnn.weights_projection_pack_size,
                rnn.n_parts_weights_projection, rnn.parts_weights_projection,
                rnn.part_weights_projection_pack_size,
                rnn.weights_projection_comp_offset, rnn.dhc, rnn.dic, rnn.dic,
                rnn.scratch_ht_ld);
        if (!ok) return false;
    }

    return true;
}

template <typename T>
void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d,
        const memory_desc_wrapper &diff_weights_projection_d) {

    // Set leading dimensions for input weights arrays depending on input format
    auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
        ld = 0;
        nld = 0;
        if (md.is_blocking_desc()) {
            if (is_ldigo(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else if (is_ldgoi(md)) {
                ld = (int)md.blocking_desc().strides[4];
                nld = md.dims()[3] * md.dims()[4];
            } else if (is_ldoi(md)) {
                ld = (int)md.blocking_desc().strides[3];
                nld = md.dims()[3];
            } else if (is_ldio(md)) {
                ld = (int)md.blocking_desc().strides[2];
                nld = md.dims()[2];
            } else
                assert(!"unsupported weights format");
        }
    };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    set_dims(weights_projection_d, rnn.weights_projection_ld,
            rnn.weights_projection_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
        set_dims(diff_weights_projection_d, rnn.diff_weights_projection_ld,
                rnn.diff_weights_projection_nld);
    }

    assert(weights_layer_d.data_type() == weights_iter_d.data_type());
    assert(IMPLICATION(diff_weights_layer_d.ndims() != 0,
            (diff_weights_layer_d.data_type()
                    == diff_weights_iter_d.data_type())));
    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */

    assert(sizeof(typename T::src_layer_t) == sizeof(typename T::dst_layer_t));
    assert(sizeof(typename T::src_iter_t) == sizeof(typename T::dst_iter_t));

    rnn.use_workspace = rnn.is_training;
    // TODO: for inference, we can make ws_states_* smaller, but
    // that would depend on the grid execution
    rnn.ws_states_layer_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_layer_ld
            * sizeof(typename T::src_layer_t);
    rnn.ws_states_iter_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_ld
            * sizeof(typename T::src_iter_t);
    bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm;
    rnn.ws_states_iter_c_size = is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_states_iter_c_ld * sizeof(float)
            : 0;

    rnn.ws_diff_states_layer_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_layer_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_c_size = rnn.is_training && is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_c_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;

    rnn.ws_gates_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_gates_nld
                    * rnn.ws_gates_ld * sizeof(typename T::gates_t)
            : (size_t)0;
    rnn.ws_ht_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_ht_nld
                    * rnn.ws_ht_ld * sizeof(typename T::dst_iter_t)
            : (size_t)0;
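
    // Worked example (hypothetical problem): a single-layer, single-direction
    // training LSTM with n_iter = 10, mb = 32, dhc = 512, n_gates = 4 and f32
    // data gives ws_gates_size = 1 * 1 * 10 * 32 * ws_gates_ld * 4 bytes,
    // where ws_gates_ld = get_good_ld(4 * 512, 4) >= 2048.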
    rnn.n_iter_scratch_gates
            = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1;
    rnn.scratch_gates_size = rnn.n_iter_scratch_gates * rnn.scratch_gates_nld
            * rnn.scratch_gates_ld * sizeof(typename T::scratch_t);
    rnn.scratch_ht_size
            = rnn.scratch_ht_nld * rnn.scratch_ht_ld * sizeof(typename T::ht_t);
    rnn.scratch_diff_ht_size = rnn.is_training
            ? rnn.scratch_diff_ht_nld * rnn.scratch_diff_ht_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;

    /* set other sizes */
    /// scratchpad buffer for each cell to hold intermediate data in gru/lbr_gru
    rnn.scratch_cell_size = rnn.is_lbr
            ? (size_t)rnn.scratch_gates_nld * rnn.scratch_gates_ld
                    * sizeof(typename T::gemm_acc_t)
            : (rd.cell_kind == alg_kind::vanilla_gru
                            ? (size_t)rnn.ws_states_layer_nld
                                    * rnn.ws_states_layer_ld
                                    * sizeof(typename T::gemm_acc_t)
                            : 0);
    /// workspace needed for lbr GRU
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc
            * sizeof(typename T::gemm_acc_t);
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
    /// bias ws needed to add compensation in int8
    rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc
            * sizeof(float);
}

void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_ht_offset, size_t &ws_state_layer_offset,
        size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset,
        size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset,
        size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset,
        size_t &scratch_cell_offset, size_t &scratchpad_size,
        size_t &workspace_size);

void get_scratchpad_and_workspace_sizes(
        const rnn_conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size);
status_t set_expected_desc(rnn_conf_t &rnn, memory_desc_t &weights_md,
        weights_type_t weights_type);
status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);

template <typename T>
struct ws_gates_aoc {
    ws_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.ws_gates_nld, rnn.ws_gates_ld), DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using ws_gates_aoc_t = ws_gates_aoc<float>;
using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
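
// Illustrative (hypothetical) indexing through the helper above:
//     ws_gates_aoc_t gates(rnn, ws_gates_ptr);
//     float &g = gates(b, 2, d);
// resolves to ws_gates_ptr[b * rnn.ws_gates_ld + 2 * rnn.dhc + d], i.e. the
// d-th channel of the third gate for batch entry b.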

template <typename T>
struct ws_ht_aoc {
    ws_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.ws_ht_nld, rnn.ws_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};

template <typename T>
struct scratch_gates_aoc {
    scratch_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.scratch_gates_nld, rnn.scratch_gates_ld)
        , DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using scratch_gates_aoc_t = scratch_gates_aoc<float>;
using scratch_gates_aoc_s32_t = scratch_gates_aoc<int32_t>;

template <typename T>
struct scratch_ht_aoc {
    scratch_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.scratch_ht_nld, rnn.scratch_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};
using scratch_ht_aoc_t = scratch_ht_aoc<float>;
using scratch_ht_aoc_s32_t = scratch_ht_aoc<int32_t>;

template <typename T>
struct weights_peephole_aoc_t {
    weights_peephole_aoc_t(const rnn_conf_t &rnn, T *data)
        : weights_peephole_(data, 3, rnn.dhc) {}
    T &operator()(int g, int dhc) const { return weights_peephole_(g, dhc); }

private:
    const utils::array_offset_calculator<T, 2> weights_peephole_;
};

struct bias_aoc_t {
    bias_aoc_t(const rnn_conf_t &rnn, const float *data)
        : bias_(data, rnn.n_bias, rnn.dhc) {}
    const float &operator()(int bias_n, int dhc) const {
        return bias_(bias_n, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<const float, 2> bias_;
};

template <typename T>
struct ws_states_layer_aoc {
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_layer_nld, leading_dim) {}
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_states_iter_aoc {
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_iter_nld, leading_dim) {}
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_iter_nld, rnn.ws_states_iter_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_states_iter_c_aoc {
    ws_states_iter_c_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_iter_c_nld, leading_dim) {}
    ws_states_iter_c_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_iter_c_nld, rnn.ws_states_iter_c_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_diff_states_layer_aoc {
    ws_diff_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_layer_(data, rnn.ws_diff_states_layer_nld,
                rnn.ws_diff_states_layer_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_layer_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_layer_;
};

template <typename T>
struct ws_diff_states_iter_aoc {
    ws_diff_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_(data, rnn.ws_diff_states_iter_nld,
                rnn.ws_diff_states_iter_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_;
};

template <typename T>
struct ws_diff_states_iter_c_aoc {
    ws_diff_states_iter_c_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_c_(data, rnn.ws_diff_states_iter_c_nld,
                rnn.ws_diff_states_iter_c_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_c_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_c_;
};

struct ws_diff_w_iter_aoc_t {
    ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
        : diff_weights_iter_(
                data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
        , DHC_(rnn.dhc) {}
    float &operator()(int sic, int gate, int dhc) const {
        return diff_weights_iter_(sic, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<float, 2>
            diff_weights_iter_;
    const int DHC_;
};

} // namespace rnn_utils
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif