1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <initializer_list>
18 
19 #include "oneapi/dnnl/dnnl.h"
20 
21 #include "c_types_map.hpp"
22 #include "type_helpers.hpp"
23 #include "utils.hpp"
24 
25 namespace dnnl {
26 namespace impl {
27 namespace rnn {
28 
get_gates_count(dnnl_alg_kind_t cell_kind)29 int get_gates_count(dnnl_alg_kind_t cell_kind) {
30     switch (cell_kind) {
31         case dnnl::impl::alg_kind::vanilla_rnn: return 1;
32         case dnnl::impl::alg_kind::vanilla_gru: return 3;
33         case dnnl::impl::alg_kind::lbr_gru: return 3;
34         case dnnl::impl::alg_kind::vanilla_lstm: return 4;
35         default: assert(!"unknown cell kind"); return 0;
36     }
37     return 0;
38 }
39 
40 } // namespace rnn
41 } // namespace impl
42 } // namespace dnnl
43 
44 namespace {
45 using namespace dnnl::impl;
46 using namespace dnnl::impl::status;
47 using namespace dnnl::impl::types;
48 using namespace dnnl::impl::utils;
49 
maybe_init_md(memory_desc_t & md,const memory_desc_t * with_md)50 void maybe_init_md(memory_desc_t &md, const memory_desc_t *with_md) {
51     if (with_md) md = *with_md;
52 }
53 
xnor_md(const memory_desc_t * a_md,const memory_desc_t * b_md)54 bool xnor_md(const memory_desc_t *a_md, const memory_desc_t *b_md) {
55     return is_zero_md(a_md) == is_zero_md(b_md);
56 }
57 
check_runtime_dims_or_strides(std::initializer_list<const memory_desc_t * > l)58 status_t check_runtime_dims_or_strides(
59         std::initializer_list<const memory_desc_t *> l) {
60     bool runtime_dims_or_strides = false;
61     for (auto md : l)
62         runtime_dims_or_strides = runtime_dims_or_strides
63                 || memory_desc_wrapper(md).has_runtime_dims_or_strides();
64     return runtime_dims_or_strides ? unimplemented : success;
65 }
66 
67 template <typename... DTs>
expect_dt(const memory_desc_t & md,DTs...dts)68 bool expect_dt(const memory_desc_t &md, DTs... dts) {
69     return IMPLICATION(!is_zero_md(&md), utils::one_of(md.data_type, dts...));
70 }
71 
expect_dims(const memory_desc_t & md,std::initializer_list<dim_t> dims,bool allow_zero=true)72 status_t expect_dims(const memory_desc_t &md, std::initializer_list<dim_t> dims,
73         bool allow_zero = true) {
74     if (is_zero_md(&md))
75         return (allow_zero || dims.size() == 0) ? success : invalid_arguments;
76 
77     if (md.ndims != (int)dims.size()) return invalid_arguments;
78 
79     int d_in_md = 0;
80     for (auto d : dims)
81         if (d != md.dims[d_in_md++]) return invalid_arguments;
82 
83     return success;
84 }
85 
check_data_type_consistency_fwd(const rnn_desc_t & r)86 status_t check_data_type_consistency_fwd(const rnn_desc_t &r) {
87     using namespace data_type;
88     data_type_t src_layer_dt = r.src_layer_desc.data_type;
89     data_type_t dst_layer_dt = r.dst_layer_desc.data_type;
90     data_type_t weights_iter_dt = r.weights_iter_desc.data_type;
91     data_type_t weights_layer_dt = r.weights_layer_desc.data_type;
92     data_type_t weights_projection_dt = r.weights_projection_desc.data_type;
93 
94     const bool is_forward = !(r.prop_kind == prop_kind::backward);
95     const bool is_inference = r.prop_kind == prop_kind::forward_inference;
96     const bool is_int8_ok
97             = one_of(r.cell_kind, dnnl_vanilla_lstm, dnnl_vanilla_gru);
98 
99     const bool cell_state_check = expect_dt(r.src_iter_c_desc, f32, f16)
100             && expect_dt(r.dst_iter_c_desc, f32, f16);
101 
102     const bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt,
103                                 weights_iter_dt, weights_layer_dt)
104             && expect_dt(r.src_iter_desc, f32)
105             && expect_dt(r.weights_peephole_desc, f32)
106             && expect_dt(r.weights_projection_desc, f32)
107             && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
108 
109     const bool is_bf16 = everyone_is(bf16, src_layer_dt, dst_layer_dt,
110                                  weights_iter_dt, weights_layer_dt)
111             && expect_dt(r.src_iter_desc, bf16)
112             && expect_dt(r.weights_peephole_desc, f32)
113             && one_of(weights_projection_dt, bf16, data_type::undef)
114             && expect_dt(r.dst_iter_desc, bf16) && expect_dt(r.bias_desc, f32);
115 
116     const bool is_f16 = is_forward
117             && everyone_is(f16, src_layer_dt, dst_layer_dt, weights_iter_dt,
118                     weights_layer_dt)
119             && expect_dt(r.src_iter_desc, f16)
120             && expect_dt(r.weights_peephole_desc, f16)
121             && r.weights_peephole_desc.data_type == data_type::undef
122             && expect_dt(r.dst_iter_desc, f16) && expect_dt(r.bias_desc, f16);
123 
124     const bool is_u8u8u8 = is_inference && is_int8_ok && src_layer_dt == u8
125             && one_of(dst_layer_dt, u8, f32)
126             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
127             && expect_dt(r.src_iter_desc, u8)
128             && expect_dt(r.src_iter_c_desc, f32)
129             && r.weights_peephole_desc.data_type == data_type::undef
130             && one_of(weights_projection_dt, s8, data_type::undef)
131             && expect_dt(r.dst_iter_desc, u8)
132             && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
133 
134     const bool is_f32u8f32 = is_inference && is_int8_ok && src_layer_dt == u8
135             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
136             && r.weights_peephole_desc.data_type == data_type::undef
137             && one_of(weights_projection_dt, s8, data_type::undef)
138             && one_of(dst_layer_dt, u8, f32) && expect_dt(r.src_iter_desc, f32)
139             && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
140 
141     const bool is_s8s8s8 = is_inference && is_int8_ok && src_layer_dt == s8
142             && one_of(dst_layer_dt, s8, f32)
143             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
144             && expect_dt(r.src_iter_desc, s8)
145             && expect_dt(r.src_iter_c_desc, f32)
146             && r.weights_peephole_desc.data_type == data_type::undef
147             && one_of(weights_projection_dt, s8, data_type::undef)
148             && expect_dt(r.dst_iter_desc, s8)
149             && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
150 
151     const bool is_f32s8f32 = is_inference && is_int8_ok && src_layer_dt == s8
152             && everyone_is(s8, weights_iter_dt, weights_layer_dt)
153             && r.weights_peephole_desc.data_type == data_type::undef
154             && one_of(weights_projection_dt, s8, data_type::undef)
155             && one_of(dst_layer_dt, s8, f32) && expect_dt(r.src_iter_desc, f32)
156             && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
157 
158     return cell_state_check
159                     && (is_f32 || is_bf16 || is_f16 || is_u8u8u8 || is_f32u8f32
160                             || is_s8s8s8 || is_f32s8f32)
161             ? success
162             : unimplemented;
163 }
164 
check_data_type_consistency_bwd(const rnn_desc_t & r)165 status_t check_data_type_consistency_bwd(const rnn_desc_t &r) {
166     using namespace data_type;
167 
168     /* We require diffs to be f32, even for bf16 */
169     bool are_diff_f32 = everyone_is(f32, r.diff_src_layer_desc.data_type,
170                                 r.diff_dst_layer_desc.data_type,
171                                 r.diff_weights_iter_desc.data_type,
172                                 r.diff_weights_layer_desc.data_type)
173             && expect_dt(r.diff_src_iter_desc, f32)
174             && expect_dt(r.diff_dst_iter_desc, f32)
175             && expect_dt(r.diff_weights_peephole_desc, f32)
176             && expect_dt(r.diff_weights_projection_desc, f32)
177             && expect_dt(r.diff_bias_desc, f32)
178             && expect_dt(r.diff_src_iter_c_desc, f32)
179             && expect_dt(r.diff_dst_iter_c_desc, f32);
180 
181     return are_diff_f32 ? success : unimplemented;
182 }
183 
check_dim_consistency(const rnn_desc_t & r)184 status_t check_dim_consistency(const rnn_desc_t &r) {
185     const bool is_lstm_projection = r.cell_kind == dnnl_vanilla_lstm
186             && !is_zero_md(&r.weights_projection_desc);
187 
188     const dim_t L = r.weights_layer_desc.dims[0];
189     const dim_t T = r.src_layer_desc.dims[0];
190     const dim_t N = r.src_layer_desc.dims[1];
191     const dim_t D = one_of(r.direction, dnnl_unidirectional_left2right,
192                             dnnl_unidirectional_right2left)
193             ? 1
194             : 2;
195     const dim_t G = rnn::get_gates_count(r.cell_kind);
196     const dim_t SLC = r.src_layer_desc.dims[2];
197     const dim_t SIC = r.weights_iter_desc.dims[2];
198     const dim_t DLC = r.dst_layer_desc.dims[2];
199     const dim_t DHC = r.weights_layer_desc.dims[4];
200     const dim_t DIC
201             = is_lstm_projection ? r.weights_projection_desc.dims[3] : DHC;
202 
203     const bool extra_bias = r.cell_kind == alg_kind::lbr_gru;
204     const dim_t dlc_multiplier
205             = (r.direction == dnnl_bidirectional_concat) ? 2 : 1;
206 
207     bool args_ok = IMPLICATION(utils::one_of(r.cell_kind, alg_kind::vanilla_gru,
208                                        alg_kind::lbr_gru),
209                            SIC == DHC)
210             && dlc_multiplier * DIC == DLC
211             && IMPLICATION(L > 1, dlc_multiplier * SLC == DLC)
212             && IMPLICATION(T > 1, SIC == DIC);
213     if (!args_ok) return invalid_arguments;
214 
215     CHECK(expect_dims(r.src_layer_desc, {T, N, SLC}, false));
216     CHECK(expect_dims(r.src_iter_desc, {L, D, N, SIC}));
217     CHECK(expect_dims(r.src_iter_c_desc, {L, D, N, DHC}));
218     CHECK(expect_dims(r.weights_layer_desc, {L, D, SLC, G, DHC}, false));
219     CHECK(expect_dims(r.weights_iter_desc, {L, D, SIC, G, DHC}, false));
220     CHECK(expect_dims(r.weights_peephole_desc, {L, D, 3, DHC}));
221     CHECK(expect_dims(r.weights_projection_desc, {L, D, DHC, DIC}));
222     CHECK(expect_dims(r.bias_desc, {L, D, G + extra_bias, DHC}));
223     CHECK(expect_dims(r.dst_layer_desc, {T, N, DLC}, false));
224     CHECK(expect_dims(r.dst_iter_desc, {L, D, N, DIC}));
225     CHECK(expect_dims(r.dst_iter_c_desc, {L, D, N, DHC}));
226 
227     if (r.prop_kind == prop_kind::backward) {
228         CHECK(expect_dims(r.diff_src_layer_desc, {T, N, SLC}, false));
229         CHECK(expect_dims(r.diff_src_iter_desc, {L, D, N, SIC}));
230         CHECK(expect_dims(r.diff_src_iter_c_desc, {L, D, N, DHC}));
231         CHECK(expect_dims(
232                 r.diff_weights_layer_desc, {L, D, SLC, G, DHC}, false));
233         CHECK(expect_dims(
234                 r.diff_weights_iter_desc, {L, D, SIC, G, DHC}, false));
235         CHECK(expect_dims(r.diff_weights_peephole_desc, {L, D, 3, DHC}));
236         CHECK(expect_dims(r.diff_weights_projection_desc, {L, D, DHC, DIC}));
237         CHECK(expect_dims(r.diff_bias_desc, {L, D, G + extra_bias, DHC}));
238         CHECK(expect_dims(r.diff_dst_layer_desc, {T, N, DLC}, false));
239         CHECK(expect_dims(r.diff_dst_iter_desc, {L, D, N, DIC}));
240         CHECK(expect_dims(r.diff_dst_iter_c_desc, {L, D, N, DHC}));
241     }
242 
243     return success;
244 }
245 
rnn_common_fwd_desc_init(dnnl_rnn_desc_t * rnn_desc,prop_kind_t prop_kind,dnnl_alg_kind_t cell_kind,const rnn_direction_t direction,const memory_desc_t * src_layer_desc,const memory_desc_t * src_iter_desc,const memory_desc_t * src_iter_c_desc,const memory_desc_t * weights_layer_desc,const memory_desc_t * weights_iter_desc,const memory_desc_t * weights_peephole_desc,const memory_desc_t * weights_projection_desc,const memory_desc_t * bias_desc,const memory_desc_t * dst_layer_desc,const memory_desc_t * dst_iter_desc,const memory_desc_t * dst_iter_c_desc,unsigned flags,dnnl_alg_kind_t activation=dnnl_alg_kind_undef,float alpha=0.0f,float beta=0.0f)246 status_t rnn_common_fwd_desc_init(dnnl_rnn_desc_t *rnn_desc,
247         prop_kind_t prop_kind, dnnl_alg_kind_t cell_kind,
248         const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
249         const memory_desc_t *src_iter_desc,
250         const memory_desc_t *src_iter_c_desc,
251         const memory_desc_t *weights_layer_desc,
252         const memory_desc_t *weights_iter_desc,
253         const memory_desc_t *weights_peephole_desc,
254         const memory_desc_t *weights_projection_desc,
255         const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
256         const memory_desc_t *dst_iter_desc,
257         const memory_desc_t *dst_iter_c_desc, unsigned flags,
258         dnnl_alg_kind_t activation = dnnl_alg_kind_undef, float alpha = 0.0f,
259         float beta = 0.0f) {
260 
261     // check that a supported cell kind has been passed
262     bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
263             dnnl_vanilla_gru, dnnl_lbr_gru);
264     if (!args_ok) return invalid_arguments;
265 
266     // check that all mandatory parameters are non-null
267     args_ok = args_ok
268             && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
269                     dst_layer_desc);
270     if (!args_ok) return invalid_arguments;
271 
272     if (cell_kind == dnnl_vanilla_rnn) {
273         using namespace alg_kind;
274         args_ok = args_ok
275                 && one_of(activation, eltwise_relu, eltwise_tanh,
276                         eltwise_logistic);
277         if (!args_ok) return invalid_arguments;
278     }
279 
280     if (cell_kind == dnnl_vanilla_lstm) {
281         // check if optional *_iter is provided then *_iter_c is provided too
282         args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
283                 && xnor_md(dst_iter_desc, dst_iter_c_desc);
284         if (!args_ok) return invalid_arguments;
285     }
286 
287     CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
288             src_iter_c_desc, weights_layer_desc, weights_iter_desc,
289             weights_peephole_desc, weights_projection_desc, bias_desc,
290             dst_layer_desc, dst_iter_desc, dst_iter_c_desc}));
291 
292     // Create the descriptor
293     auto rd = rnn_desc_t();
294 
295     rd.primitive_kind = primitive_kind::rnn;
296     rd.prop_kind = prop_kind;
297     rd.cell_kind = cell_kind;
298     rd.direction = direction;
299     maybe_init_md(rd.src_layer_desc, src_layer_desc);
300     maybe_init_md(rd.src_iter_desc, src_iter_desc);
301     maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
302     maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
303     maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
304     maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
305     maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
306     maybe_init_md(rd.bias_desc, bias_desc);
307     maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
308     maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
309     maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
310 
311     rd.flags = flags;
312     rd.activation_kind = activation;
313     rd.alpha = alpha;
314     rd.beta = beta;
315 
316     CHECK(check_data_type_consistency_fwd(rd));
317     CHECK(check_dim_consistency(rd));
318 
319     *rnn_desc = rd;
320 
321     return success;
322 }
323 
rnn_common_bwd_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_alg_kind_t cell_kind,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_weights_projection_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags,dnnl_alg_kind_t activation=dnnl_alg_kind_undef,float alpha=0.0f,float beta=0.0f)324 status_t rnn_common_bwd_desc_init(dnnl_rnn_desc_t *rnn_desc,
325         dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t cell_kind,
326         const dnnl_rnn_direction_t direction,
327         const dnnl_memory_desc_t *src_layer_desc,
328         const dnnl_memory_desc_t *src_iter_desc,
329         const dnnl_memory_desc_t *src_iter_c_desc,
330         const dnnl_memory_desc_t *weights_layer_desc,
331         const dnnl_memory_desc_t *weights_iter_desc,
332         const dnnl_memory_desc_t *weights_peephole_desc,
333         const dnnl_memory_desc_t *weights_projection_desc,
334         const dnnl_memory_desc_t *bias_desc,
335         const dnnl_memory_desc_t *dst_layer_desc,
336         const dnnl_memory_desc_t *dst_iter_desc,
337         const dnnl_memory_desc_t *dst_iter_c_desc,
338         const dnnl_memory_desc_t *diff_src_layer_desc,
339         const dnnl_memory_desc_t *diff_src_iter_desc,
340         const dnnl_memory_desc_t *diff_src_iter_c_desc,
341         const dnnl_memory_desc_t *diff_weights_layer_desc,
342         const dnnl_memory_desc_t *diff_weights_iter_desc,
343         const dnnl_memory_desc_t *diff_weights_peephole_desc,
344         const dnnl_memory_desc_t *diff_weights_projection_desc,
345         const dnnl_memory_desc_t *diff_bias_desc,
346         const dnnl_memory_desc_t *diff_dst_layer_desc,
347         const dnnl_memory_desc_t *diff_dst_iter_desc,
348         const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
349         dnnl_alg_kind_t activation = dnnl_alg_kind_undef, float alpha = 0.0f,
350         float beta = 0.0f) {
351 
352     // check that a supported cell kind has been passed
353     bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
354             dnnl_vanilla_gru, dnnl_lbr_gru);
355     if (!args_ok) return invalid_arguments;
356 
357     // check that all mandatory parameters are non-null
358     args_ok = args_ok
359             && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
360                     dst_layer_desc, diff_src_layer_desc,
361                     diff_weights_layer_desc, diff_weights_iter_desc,
362                     diff_dst_layer_desc);
363     if (!args_ok) return invalid_arguments;
364 
365     if (cell_kind == dnnl_vanilla_rnn) {
366         using namespace alg_kind;
367         args_ok = args_ok
368                 && one_of(activation, eltwise_relu, eltwise_tanh,
369                         eltwise_logistic);
370         if (!args_ok) return invalid_arguments;
371     }
372 
373     if (cell_kind == dnnl_vanilla_lstm) {
374         // check if optional *_iter is provided then *_iter_c is provided too
375         args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
376                 && xnor_md(dst_iter_desc, dst_iter_c_desc);
377         if (!args_ok) return invalid_arguments;
378     }
379 
380     // check if optional md is provided then diff_md is provided too
381     args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
382             && xnor_md(weights_peephole_desc, diff_weights_peephole_desc)
383             && xnor_md(weights_projection_desc, diff_weights_projection_desc)
384             && xnor_md(src_iter_desc, diff_src_iter_desc)
385             && xnor_md(src_iter_c_desc, diff_src_iter_c_desc)
386             && xnor_md(dst_iter_desc, diff_dst_iter_desc)
387             && xnor_md(dst_iter_c_desc, diff_dst_iter_c_desc);
388     if (!args_ok) return invalid_arguments;
389 
390     CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
391             src_iter_c_desc, weights_layer_desc, weights_iter_desc,
392             weights_peephole_desc, weights_projection_desc, bias_desc,
393             dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
394             diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
395             diff_weights_iter_desc, diff_weights_peephole_desc,
396             diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
397             diff_dst_iter_desc, diff_dst_iter_c_desc}));
398 
399     auto rd = dnnl_rnn_desc_t();
400 
401     rd.primitive_kind = primitive_kind::rnn;
402     rd.prop_kind = prop_kind;
403     rd.cell_kind = cell_kind;
404     rd.direction = direction;
405 
406     maybe_init_md(rd.src_layer_desc, src_layer_desc);
407     maybe_init_md(rd.src_iter_desc, src_iter_desc);
408     maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
409     maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
410     maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
411     maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
412     maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
413     maybe_init_md(rd.bias_desc, bias_desc);
414     maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
415     maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
416     maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
417     maybe_init_md(rd.diff_src_layer_desc, diff_src_layer_desc);
418     maybe_init_md(rd.diff_src_iter_desc, diff_src_iter_desc);
419     maybe_init_md(rd.diff_src_iter_c_desc, diff_src_iter_c_desc);
420     maybe_init_md(rd.diff_weights_layer_desc, diff_weights_layer_desc);
421     maybe_init_md(rd.diff_weights_iter_desc, diff_weights_iter_desc);
422     maybe_init_md(rd.diff_weights_peephole_desc, diff_weights_peephole_desc);
423     maybe_init_md(
424             rd.diff_weights_projection_desc, diff_weights_projection_desc);
425     maybe_init_md(rd.diff_bias_desc, diff_bias_desc);
426     maybe_init_md(rd.diff_dst_layer_desc, diff_dst_layer_desc);
427     maybe_init_md(rd.diff_dst_iter_desc, diff_dst_iter_desc);
428     maybe_init_md(rd.diff_dst_iter_c_desc, diff_dst_iter_c_desc);
429 
430     rd.flags = flags;
431     rd.activation_kind = activation;
432     rd.alpha = alpha;
433     rd.beta = beta;
434 
435     CHECK(check_data_type_consistency_fwd(rd));
436     CHECK(check_data_type_consistency_bwd(rd));
437 
438     CHECK(check_dim_consistency(rd));
439 
440     *rnn_desc = rd;
441 
442     return success;
443 }
444 
445 } // namespace
446 
447 /* Public C Api */
448 
dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,const dnnl_alg_kind_t activation,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags,float alpha,float beta)449 status_t dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
450         dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
451         const dnnl_rnn_direction_t direction,
452         const dnnl_memory_desc_t *src_layer_desc,
453         const dnnl_memory_desc_t *src_iter_desc,
454         const dnnl_memory_desc_t *weights_layer_desc,
455         const dnnl_memory_desc_t *weights_iter_desc,
456         const dnnl_memory_desc_t *bias_desc,
457         const dnnl_memory_desc_t *dst_layer_desc,
458         const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha,
459         float beta) {
460     status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
461             dnnl_vanilla_rnn, direction, src_layer_desc, src_iter_desc, nullptr,
462             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
463             dst_layer_desc, dst_iter_desc, nullptr, flags, activation, alpha,
464             beta);
465     return st;
466 }
467 
dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)468 status_t dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
469         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
470         const dnnl_memory_desc_t *src_layer_desc,
471         const dnnl_memory_desc_t *src_iter_desc,
472         const dnnl_memory_desc_t *src_iter_c_desc,
473         const dnnl_memory_desc_t *weights_layer_desc,
474         const dnnl_memory_desc_t *weights_iter_desc,
475         const dnnl_memory_desc_t *bias_desc,
476         const dnnl_memory_desc_t *dst_layer_desc,
477         const dnnl_memory_desc_t *dst_iter_desc,
478         const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
479     return dnnl_lstm_forward_desc_init_v3(rnn_desc, prop_kind, direction,
480             src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
481             weights_iter_desc, nullptr, nullptr, bias_desc, dst_layer_desc,
482             dst_iter_desc, dst_iter_c_desc, flags);
483 }
484 
dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)485 status_t dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
486         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
487         const dnnl_memory_desc_t *src_layer_desc,
488         const dnnl_memory_desc_t *src_iter_desc,
489         const dnnl_memory_desc_t *src_iter_c_desc,
490         const dnnl_memory_desc_t *weights_layer_desc,
491         const dnnl_memory_desc_t *weights_iter_desc,
492         const dnnl_memory_desc_t *weights_peephole_desc,
493         const dnnl_memory_desc_t *bias_desc,
494         const dnnl_memory_desc_t *dst_layer_desc,
495         const dnnl_memory_desc_t *dst_iter_desc,
496         const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
497     return dnnl_lstm_forward_desc_init_v3(rnn_desc, prop_kind, direction,
498             src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
499             weights_iter_desc, weights_peephole_desc, nullptr, bias_desc,
500             dst_layer_desc, dst_iter_desc, dst_iter_c_desc, flags);
501 }
502 
dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)503 status_t dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
504         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
505         const dnnl_memory_desc_t *src_layer_desc,
506         const dnnl_memory_desc_t *src_iter_desc,
507         const dnnl_memory_desc_t *src_iter_c_desc,
508         const dnnl_memory_desc_t *weights_layer_desc,
509         const dnnl_memory_desc_t *weights_iter_desc,
510         const dnnl_memory_desc_t *weights_peephole_desc,
511         const dnnl_memory_desc_t *weights_projection_desc,
512         const dnnl_memory_desc_t *bias_desc,
513         const dnnl_memory_desc_t *dst_layer_desc,
514         const dnnl_memory_desc_t *dst_iter_desc,
515         const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
516     status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
517             dnnl_vanilla_lstm, direction, src_layer_desc, src_iter_desc,
518             src_iter_c_desc, weights_layer_desc, weights_iter_desc,
519             weights_peephole_desc, weights_projection_desc, bias_desc,
520             dst_layer_desc, dst_iter_desc, dst_iter_c_desc, flags);
521     return st;
522 }
523 
dnnl_gru_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags)524 status_t dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
525         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
526         const dnnl_memory_desc_t *src_layer_desc,
527         const dnnl_memory_desc_t *src_iter_desc,
528         const dnnl_memory_desc_t *weights_layer_desc,
529         const dnnl_memory_desc_t *weights_iter_desc,
530         const dnnl_memory_desc_t *bias_desc,
531         const dnnl_memory_desc_t *dst_layer_desc,
532         const dnnl_memory_desc_t *dst_iter_desc, unsigned flags) {
533     status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
534             dnnl_vanilla_gru, direction, src_layer_desc, src_iter_desc, nullptr,
535             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
536             dst_layer_desc, dst_iter_desc, nullptr, flags);
537     return st;
538 }
539 
dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags)540 status_t dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
541         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
542         const dnnl_memory_desc_t *src_layer_desc,
543         const dnnl_memory_desc_t *src_iter_desc,
544         const dnnl_memory_desc_t *weights_layer_desc,
545         const dnnl_memory_desc_t *weights_iter_desc,
546         const dnnl_memory_desc_t *bias_desc,
547         const dnnl_memory_desc_t *dst_layer_desc,
548         const dnnl_memory_desc_t *dst_iter_desc, unsigned flags) {
549     status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind, dnnl_lbr_gru,
550             direction, src_layer_desc, src_iter_desc, nullptr,
551             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
552             dst_layer_desc, dst_iter_desc, nullptr, flags);
553     return st;
554 }
555 
dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,const dnnl_alg_kind_t activation,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags,float alpha,float beta)556 status_t dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
557         dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
558         const dnnl_rnn_direction_t direction,
559         const dnnl_memory_desc_t *src_layer_desc,
560         const dnnl_memory_desc_t *src_iter_desc,
561         const dnnl_memory_desc_t *weights_layer_desc,
562         const dnnl_memory_desc_t *weights_iter_desc,
563         const dnnl_memory_desc_t *bias_desc,
564         const dnnl_memory_desc_t *dst_layer_desc,
565         const dnnl_memory_desc_t *dst_iter_desc,
566         const dnnl_memory_desc_t *diff_src_layer_desc,
567         const dnnl_memory_desc_t *diff_src_iter_desc,
568         const dnnl_memory_desc_t *diff_weights_layer_desc,
569         const dnnl_memory_desc_t *diff_weights_iter_desc,
570         const dnnl_memory_desc_t *diff_bias_desc,
571         const dnnl_memory_desc_t *diff_dst_layer_desc,
572         const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
573         float alpha, float beta) {
574     status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
575             dnnl_vanilla_rnn, direction, src_layer_desc, src_iter_desc, nullptr,
576             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
577             dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
578             diff_src_iter_desc, nullptr, diff_weights_layer_desc,
579             diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
580             diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags, activation,
581             alpha, beta);
582     return st;
583 }
584 
dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)585 status_t dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
586         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
587         const dnnl_memory_desc_t *src_layer_desc,
588         const dnnl_memory_desc_t *src_iter_desc,
589         const dnnl_memory_desc_t *src_iter_c_desc,
590         const dnnl_memory_desc_t *weights_layer_desc,
591         const dnnl_memory_desc_t *weights_iter_desc,
592         const dnnl_memory_desc_t *bias_desc,
593         const dnnl_memory_desc_t *dst_layer_desc,
594         const dnnl_memory_desc_t *dst_iter_desc,
595         const dnnl_memory_desc_t *dst_iter_c_desc,
596         const dnnl_memory_desc_t *diff_src_layer_desc,
597         const dnnl_memory_desc_t *diff_src_iter_desc,
598         const dnnl_memory_desc_t *diff_src_iter_c_desc,
599         const dnnl_memory_desc_t *diff_weights_layer_desc,
600         const dnnl_memory_desc_t *diff_weights_iter_desc,
601         const dnnl_memory_desc_t *diff_bias_desc,
602         const dnnl_memory_desc_t *diff_dst_layer_desc,
603         const dnnl_memory_desc_t *diff_dst_iter_desc,
604         const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
605     return dnnl_lstm_backward_desc_init_v3(rnn_desc, prop_kind, direction,
606             src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
607             weights_iter_desc, nullptr, nullptr, bias_desc, dst_layer_desc,
608             dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
609             diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
610             diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
611             diff_dst_layer_desc, diff_dst_iter_desc, diff_dst_iter_c_desc,
612             flags);
613 }
614 
dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)615 status_t dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
616         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
617         const dnnl_memory_desc_t *src_layer_desc,
618         const dnnl_memory_desc_t *src_iter_desc,
619         const dnnl_memory_desc_t *src_iter_c_desc,
620         const dnnl_memory_desc_t *weights_layer_desc,
621         const dnnl_memory_desc_t *weights_iter_desc,
622         const dnnl_memory_desc_t *weights_peephole_desc,
623         const dnnl_memory_desc_t *bias_desc,
624         const dnnl_memory_desc_t *dst_layer_desc,
625         const dnnl_memory_desc_t *dst_iter_desc,
626         const dnnl_memory_desc_t *dst_iter_c_desc,
627         const dnnl_memory_desc_t *diff_src_layer_desc,
628         const dnnl_memory_desc_t *diff_src_iter_desc,
629         const dnnl_memory_desc_t *diff_src_iter_c_desc,
630         const dnnl_memory_desc_t *diff_weights_layer_desc,
631         const dnnl_memory_desc_t *diff_weights_iter_desc,
632         const dnnl_memory_desc_t *diff_weights_peephole_desc,
633         const dnnl_memory_desc_t *diff_bias_desc,
634         const dnnl_memory_desc_t *diff_dst_layer_desc,
635         const dnnl_memory_desc_t *diff_dst_iter_desc,
636         const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
637     return dnnl_lstm_backward_desc_init_v3(rnn_desc, prop_kind, direction,
638             src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
639             weights_iter_desc, weights_peephole_desc, nullptr, bias_desc,
640             dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
641             diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
642             diff_weights_iter_desc, diff_weights_peephole_desc, nullptr,
643             diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc,
644             diff_dst_iter_c_desc, flags);
645 }
646 
dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_weights_projection_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)647 status_t dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
648         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
649         const dnnl_memory_desc_t *src_layer_desc,
650         const dnnl_memory_desc_t *src_iter_desc,
651         const dnnl_memory_desc_t *src_iter_c_desc,
652         const dnnl_memory_desc_t *weights_layer_desc,
653         const dnnl_memory_desc_t *weights_iter_desc,
654         const dnnl_memory_desc_t *weights_peephole_desc,
655         const dnnl_memory_desc_t *weights_projection_desc,
656         const dnnl_memory_desc_t *bias_desc,
657         const dnnl_memory_desc_t *dst_layer_desc,
658         const dnnl_memory_desc_t *dst_iter_desc,
659         const dnnl_memory_desc_t *dst_iter_c_desc,
660         const dnnl_memory_desc_t *diff_src_layer_desc,
661         const dnnl_memory_desc_t *diff_src_iter_desc,
662         const dnnl_memory_desc_t *diff_src_iter_c_desc,
663         const dnnl_memory_desc_t *diff_weights_layer_desc,
664         const dnnl_memory_desc_t *diff_weights_iter_desc,
665         const dnnl_memory_desc_t *diff_weights_peephole_desc,
666         const dnnl_memory_desc_t *diff_weights_projection_desc,
667         const dnnl_memory_desc_t *diff_bias_desc,
668         const dnnl_memory_desc_t *diff_dst_layer_desc,
669         const dnnl_memory_desc_t *diff_dst_iter_desc,
670         const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
671     status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
672             dnnl_vanilla_lstm, direction, src_layer_desc, src_iter_desc,
673             src_iter_c_desc, weights_layer_desc, weights_iter_desc,
674             weights_peephole_desc, weights_projection_desc, bias_desc,
675             dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
676             diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
677             diff_weights_iter_desc, diff_weights_peephole_desc,
678             diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
679             diff_dst_iter_desc, diff_dst_iter_c_desc, flags);
680     return st;
681 }
682 
dnnl_gru_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags)683 status_t dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
684         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
685         const dnnl_memory_desc_t *src_layer_desc,
686         const dnnl_memory_desc_t *src_iter_desc,
687         const dnnl_memory_desc_t *weights_layer_desc,
688         const dnnl_memory_desc_t *weights_iter_desc,
689         const dnnl_memory_desc_t *bias_desc,
690         const dnnl_memory_desc_t *dst_layer_desc,
691         const dnnl_memory_desc_t *dst_iter_desc,
692         const dnnl_memory_desc_t *diff_src_layer_desc,
693         const dnnl_memory_desc_t *diff_src_iter_desc,
694         const dnnl_memory_desc_t *diff_weights_layer_desc,
695         const dnnl_memory_desc_t *diff_weights_iter_desc,
696         const dnnl_memory_desc_t *diff_bias_desc,
697         const dnnl_memory_desc_t *diff_dst_layer_desc,
698         const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags) {
699     status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
700             dnnl_vanilla_gru, direction, src_layer_desc, src_iter_desc, nullptr,
701             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
702             dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
703             diff_src_iter_desc, nullptr, diff_weights_layer_desc,
704             diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
705             diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags);
706     return st;
707 }
708 
dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags)709 status_t dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
710         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
711         const dnnl_memory_desc_t *src_layer_desc,
712         const dnnl_memory_desc_t *src_iter_desc,
713         const dnnl_memory_desc_t *weights_layer_desc,
714         const dnnl_memory_desc_t *weights_iter_desc,
715         const dnnl_memory_desc_t *bias_desc,
716         const dnnl_memory_desc_t *dst_layer_desc,
717         const dnnl_memory_desc_t *dst_iter_desc,
718         const dnnl_memory_desc_t *diff_src_layer_desc,
719         const dnnl_memory_desc_t *diff_src_iter_desc,
720         const dnnl_memory_desc_t *diff_weights_layer_desc,
721         const dnnl_memory_desc_t *diff_weights_iter_desc,
722         const dnnl_memory_desc_t *diff_bias_desc,
723         const dnnl_memory_desc_t *diff_dst_layer_desc,
724         const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags) {
725     status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind, dnnl_lbr_gru,
726             direction, src_layer_desc, src_iter_desc, nullptr,
727             weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
728             dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
729             diff_src_iter_desc, nullptr, diff_weights_layer_desc,
730             diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
731             diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags);
732     return st;
733 }
734