1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <initializer_list>
18
19 #include "oneapi/dnnl/dnnl.h"
20
21 #include "c_types_map.hpp"
22 #include "type_helpers.hpp"
23 #include "utils.hpp"
24
25 namespace dnnl {
26 namespace impl {
27 namespace rnn {
28
get_gates_count(dnnl_alg_kind_t cell_kind)29 int get_gates_count(dnnl_alg_kind_t cell_kind) {
30 switch (cell_kind) {
31 case dnnl::impl::alg_kind::vanilla_rnn: return 1;
32 case dnnl::impl::alg_kind::vanilla_gru: return 3;
33 case dnnl::impl::alg_kind::lbr_gru: return 3;
34 case dnnl::impl::alg_kind::vanilla_lstm: return 4;
35 default: assert(!"unknown cell kind"); return 0;
36 }
37 return 0;
38 }
39
40 } // namespace rnn
41 } // namespace impl
42 } // namespace dnnl
43
44 namespace {
45 using namespace dnnl::impl;
46 using namespace dnnl::impl::status;
47 using namespace dnnl::impl::types;
48 using namespace dnnl::impl::utils;
49
maybe_init_md(memory_desc_t & md,const memory_desc_t * with_md)50 void maybe_init_md(memory_desc_t &md, const memory_desc_t *with_md) {
51 if (with_md) md = *with_md;
52 }
53
xnor_md(const memory_desc_t * a_md,const memory_desc_t * b_md)54 bool xnor_md(const memory_desc_t *a_md, const memory_desc_t *b_md) {
55 return is_zero_md(a_md) == is_zero_md(b_md);
56 }
57
check_runtime_dims_or_strides(std::initializer_list<const memory_desc_t * > l)58 status_t check_runtime_dims_or_strides(
59 std::initializer_list<const memory_desc_t *> l) {
60 bool runtime_dims_or_strides = false;
61 for (auto md : l)
62 runtime_dims_or_strides = runtime_dims_or_strides
63 || memory_desc_wrapper(md).has_runtime_dims_or_strides();
64 return runtime_dims_or_strides ? unimplemented : success;
65 }
66
67 template <typename... DTs>
expect_dt(const memory_desc_t & md,DTs...dts)68 bool expect_dt(const memory_desc_t &md, DTs... dts) {
69 return IMPLICATION(!is_zero_md(&md), utils::one_of(md.data_type, dts...));
70 }
71
expect_dims(const memory_desc_t & md,std::initializer_list<dim_t> dims,bool allow_zero=true)72 status_t expect_dims(const memory_desc_t &md, std::initializer_list<dim_t> dims,
73 bool allow_zero = true) {
74 if (is_zero_md(&md))
75 return (allow_zero || dims.size() == 0) ? success : invalid_arguments;
76
77 if (md.ndims != (int)dims.size()) return invalid_arguments;
78
79 int d_in_md = 0;
80 for (auto d : dims)
81 if (d != md.dims[d_in_md++]) return invalid_arguments;
82
83 return success;
84 }
85
check_data_type_consistency_fwd(const rnn_desc_t & r)86 status_t check_data_type_consistency_fwd(const rnn_desc_t &r) {
87 using namespace data_type;
88 data_type_t src_layer_dt = r.src_layer_desc.data_type;
89 data_type_t dst_layer_dt = r.dst_layer_desc.data_type;
90 data_type_t weights_iter_dt = r.weights_iter_desc.data_type;
91 data_type_t weights_layer_dt = r.weights_layer_desc.data_type;
92 data_type_t weights_projection_dt = r.weights_projection_desc.data_type;
93
94 const bool is_forward = !(r.prop_kind == prop_kind::backward);
95 const bool is_inference = r.prop_kind == prop_kind::forward_inference;
96 const bool is_int8_ok
97 = one_of(r.cell_kind, dnnl_vanilla_lstm, dnnl_vanilla_gru);
98
99 const bool cell_state_check = expect_dt(r.src_iter_c_desc, f32, f16)
100 && expect_dt(r.dst_iter_c_desc, f32, f16);
101
102 const bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt,
103 weights_iter_dt, weights_layer_dt)
104 && expect_dt(r.src_iter_desc, f32)
105 && expect_dt(r.weights_peephole_desc, f32)
106 && expect_dt(r.weights_projection_desc, f32)
107 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
108
109 const bool is_bf16 = everyone_is(bf16, src_layer_dt, dst_layer_dt,
110 weights_iter_dt, weights_layer_dt)
111 && expect_dt(r.src_iter_desc, bf16)
112 && expect_dt(r.weights_peephole_desc, f32)
113 && one_of(weights_projection_dt, bf16, data_type::undef)
114 && expect_dt(r.dst_iter_desc, bf16) && expect_dt(r.bias_desc, f32);
115
116 const bool is_f16 = is_forward
117 && everyone_is(f16, src_layer_dt, dst_layer_dt, weights_iter_dt,
118 weights_layer_dt)
119 && expect_dt(r.src_iter_desc, f16)
120 && expect_dt(r.weights_peephole_desc, f16)
121 && r.weights_peephole_desc.data_type == data_type::undef
122 && expect_dt(r.dst_iter_desc, f16) && expect_dt(r.bias_desc, f16);
123
124 const bool is_u8u8u8 = is_inference && is_int8_ok && src_layer_dt == u8
125 && one_of(dst_layer_dt, u8, f32)
126 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
127 && expect_dt(r.src_iter_desc, u8)
128 && expect_dt(r.src_iter_c_desc, f32)
129 && r.weights_peephole_desc.data_type == data_type::undef
130 && one_of(weights_projection_dt, s8, data_type::undef)
131 && expect_dt(r.dst_iter_desc, u8)
132 && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
133
134 const bool is_f32u8f32 = is_inference && is_int8_ok && src_layer_dt == u8
135 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
136 && r.weights_peephole_desc.data_type == data_type::undef
137 && one_of(weights_projection_dt, s8, data_type::undef)
138 && one_of(dst_layer_dt, u8, f32) && expect_dt(r.src_iter_desc, f32)
139 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
140
141 const bool is_s8s8s8 = is_inference && is_int8_ok && src_layer_dt == s8
142 && one_of(dst_layer_dt, s8, f32)
143 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
144 && expect_dt(r.src_iter_desc, s8)
145 && expect_dt(r.src_iter_c_desc, f32)
146 && r.weights_peephole_desc.data_type == data_type::undef
147 && one_of(weights_projection_dt, s8, data_type::undef)
148 && expect_dt(r.dst_iter_desc, s8)
149 && expect_dt(r.dst_iter_c_desc, f32) && expect_dt(r.bias_desc, f32);
150
151 const bool is_f32s8f32 = is_inference && is_int8_ok && src_layer_dt == s8
152 && everyone_is(s8, weights_iter_dt, weights_layer_dt)
153 && r.weights_peephole_desc.data_type == data_type::undef
154 && one_of(weights_projection_dt, s8, data_type::undef)
155 && one_of(dst_layer_dt, s8, f32) && expect_dt(r.src_iter_desc, f32)
156 && expect_dt(r.dst_iter_desc, f32) && expect_dt(r.bias_desc, f32);
157
158 return cell_state_check
159 && (is_f32 || is_bf16 || is_f16 || is_u8u8u8 || is_f32u8f32
160 || is_s8s8s8 || is_f32s8f32)
161 ? success
162 : unimplemented;
163 }
164
check_data_type_consistency_bwd(const rnn_desc_t & r)165 status_t check_data_type_consistency_bwd(const rnn_desc_t &r) {
166 using namespace data_type;
167
168 /* We require diffs to be f32, even for bf16 */
169 bool are_diff_f32 = everyone_is(f32, r.diff_src_layer_desc.data_type,
170 r.diff_dst_layer_desc.data_type,
171 r.diff_weights_iter_desc.data_type,
172 r.diff_weights_layer_desc.data_type)
173 && expect_dt(r.diff_src_iter_desc, f32)
174 && expect_dt(r.diff_dst_iter_desc, f32)
175 && expect_dt(r.diff_weights_peephole_desc, f32)
176 && expect_dt(r.diff_weights_projection_desc, f32)
177 && expect_dt(r.diff_bias_desc, f32)
178 && expect_dt(r.diff_src_iter_c_desc, f32)
179 && expect_dt(r.diff_dst_iter_c_desc, f32);
180
181 return are_diff_f32 ? success : unimplemented;
182 }
183
check_dim_consistency(const rnn_desc_t & r)184 status_t check_dim_consistency(const rnn_desc_t &r) {
185 const bool is_lstm_projection = r.cell_kind == dnnl_vanilla_lstm
186 && !is_zero_md(&r.weights_projection_desc);
187
188 const dim_t L = r.weights_layer_desc.dims[0];
189 const dim_t T = r.src_layer_desc.dims[0];
190 const dim_t N = r.src_layer_desc.dims[1];
191 const dim_t D = one_of(r.direction, dnnl_unidirectional_left2right,
192 dnnl_unidirectional_right2left)
193 ? 1
194 : 2;
195 const dim_t G = rnn::get_gates_count(r.cell_kind);
196 const dim_t SLC = r.src_layer_desc.dims[2];
197 const dim_t SIC = r.weights_iter_desc.dims[2];
198 const dim_t DLC = r.dst_layer_desc.dims[2];
199 const dim_t DHC = r.weights_layer_desc.dims[4];
200 const dim_t DIC
201 = is_lstm_projection ? r.weights_projection_desc.dims[3] : DHC;
202
203 const bool extra_bias = r.cell_kind == alg_kind::lbr_gru;
204 const dim_t dlc_multiplier
205 = (r.direction == dnnl_bidirectional_concat) ? 2 : 1;
206
207 bool args_ok = IMPLICATION(utils::one_of(r.cell_kind, alg_kind::vanilla_gru,
208 alg_kind::lbr_gru),
209 SIC == DHC)
210 && dlc_multiplier * DIC == DLC
211 && IMPLICATION(L > 1, dlc_multiplier * SLC == DLC)
212 && IMPLICATION(T > 1, SIC == DIC);
213 if (!args_ok) return invalid_arguments;
214
215 CHECK(expect_dims(r.src_layer_desc, {T, N, SLC}, false));
216 CHECK(expect_dims(r.src_iter_desc, {L, D, N, SIC}));
217 CHECK(expect_dims(r.src_iter_c_desc, {L, D, N, DHC}));
218 CHECK(expect_dims(r.weights_layer_desc, {L, D, SLC, G, DHC}, false));
219 CHECK(expect_dims(r.weights_iter_desc, {L, D, SIC, G, DHC}, false));
220 CHECK(expect_dims(r.weights_peephole_desc, {L, D, 3, DHC}));
221 CHECK(expect_dims(r.weights_projection_desc, {L, D, DHC, DIC}));
222 CHECK(expect_dims(r.bias_desc, {L, D, G + extra_bias, DHC}));
223 CHECK(expect_dims(r.dst_layer_desc, {T, N, DLC}, false));
224 CHECK(expect_dims(r.dst_iter_desc, {L, D, N, DIC}));
225 CHECK(expect_dims(r.dst_iter_c_desc, {L, D, N, DHC}));
226
227 if (r.prop_kind == prop_kind::backward) {
228 CHECK(expect_dims(r.diff_src_layer_desc, {T, N, SLC}, false));
229 CHECK(expect_dims(r.diff_src_iter_desc, {L, D, N, SIC}));
230 CHECK(expect_dims(r.diff_src_iter_c_desc, {L, D, N, DHC}));
231 CHECK(expect_dims(
232 r.diff_weights_layer_desc, {L, D, SLC, G, DHC}, false));
233 CHECK(expect_dims(
234 r.diff_weights_iter_desc, {L, D, SIC, G, DHC}, false));
235 CHECK(expect_dims(r.diff_weights_peephole_desc, {L, D, 3, DHC}));
236 CHECK(expect_dims(r.diff_weights_projection_desc, {L, D, DHC, DIC}));
237 CHECK(expect_dims(r.diff_bias_desc, {L, D, G + extra_bias, DHC}));
238 CHECK(expect_dims(r.diff_dst_layer_desc, {T, N, DLC}, false));
239 CHECK(expect_dims(r.diff_dst_iter_desc, {L, D, N, DIC}));
240 CHECK(expect_dims(r.diff_dst_iter_c_desc, {L, D, N, DHC}));
241 }
242
243 return success;
244 }
245
rnn_common_fwd_desc_init(dnnl_rnn_desc_t * rnn_desc,prop_kind_t prop_kind,dnnl_alg_kind_t cell_kind,const rnn_direction_t direction,const memory_desc_t * src_layer_desc,const memory_desc_t * src_iter_desc,const memory_desc_t * src_iter_c_desc,const memory_desc_t * weights_layer_desc,const memory_desc_t * weights_iter_desc,const memory_desc_t * weights_peephole_desc,const memory_desc_t * weights_projection_desc,const memory_desc_t * bias_desc,const memory_desc_t * dst_layer_desc,const memory_desc_t * dst_iter_desc,const memory_desc_t * dst_iter_c_desc,unsigned flags,dnnl_alg_kind_t activation=dnnl_alg_kind_undef,float alpha=0.0f,float beta=0.0f)246 status_t rnn_common_fwd_desc_init(dnnl_rnn_desc_t *rnn_desc,
247 prop_kind_t prop_kind, dnnl_alg_kind_t cell_kind,
248 const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
249 const memory_desc_t *src_iter_desc,
250 const memory_desc_t *src_iter_c_desc,
251 const memory_desc_t *weights_layer_desc,
252 const memory_desc_t *weights_iter_desc,
253 const memory_desc_t *weights_peephole_desc,
254 const memory_desc_t *weights_projection_desc,
255 const memory_desc_t *bias_desc, const memory_desc_t *dst_layer_desc,
256 const memory_desc_t *dst_iter_desc,
257 const memory_desc_t *dst_iter_c_desc, unsigned flags,
258 dnnl_alg_kind_t activation = dnnl_alg_kind_undef, float alpha = 0.0f,
259 float beta = 0.0f) {
260
261 // check that a supported cell kind has been passed
262 bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
263 dnnl_vanilla_gru, dnnl_lbr_gru);
264 if (!args_ok) return invalid_arguments;
265
266 // check that all mandatory parameters are non-null
267 args_ok = args_ok
268 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
269 dst_layer_desc);
270 if (!args_ok) return invalid_arguments;
271
272 if (cell_kind == dnnl_vanilla_rnn) {
273 using namespace alg_kind;
274 args_ok = args_ok
275 && one_of(activation, eltwise_relu, eltwise_tanh,
276 eltwise_logistic);
277 if (!args_ok) return invalid_arguments;
278 }
279
280 if (cell_kind == dnnl_vanilla_lstm) {
281 // check if optional *_iter is provided then *_iter_c is provided too
282 args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
283 && xnor_md(dst_iter_desc, dst_iter_c_desc);
284 if (!args_ok) return invalid_arguments;
285 }
286
287 CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
288 src_iter_c_desc, weights_layer_desc, weights_iter_desc,
289 weights_peephole_desc, weights_projection_desc, bias_desc,
290 dst_layer_desc, dst_iter_desc, dst_iter_c_desc}));
291
292 // Create the descriptor
293 auto rd = rnn_desc_t();
294
295 rd.primitive_kind = primitive_kind::rnn;
296 rd.prop_kind = prop_kind;
297 rd.cell_kind = cell_kind;
298 rd.direction = direction;
299 maybe_init_md(rd.src_layer_desc, src_layer_desc);
300 maybe_init_md(rd.src_iter_desc, src_iter_desc);
301 maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
302 maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
303 maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
304 maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
305 maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
306 maybe_init_md(rd.bias_desc, bias_desc);
307 maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
308 maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
309 maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
310
311 rd.flags = flags;
312 rd.activation_kind = activation;
313 rd.alpha = alpha;
314 rd.beta = beta;
315
316 CHECK(check_data_type_consistency_fwd(rd));
317 CHECK(check_dim_consistency(rd));
318
319 *rnn_desc = rd;
320
321 return success;
322 }
323
rnn_common_bwd_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_alg_kind_t cell_kind,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_weights_projection_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags,dnnl_alg_kind_t activation=dnnl_alg_kind_undef,float alpha=0.0f,float beta=0.0f)324 status_t rnn_common_bwd_desc_init(dnnl_rnn_desc_t *rnn_desc,
325 dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t cell_kind,
326 const dnnl_rnn_direction_t direction,
327 const dnnl_memory_desc_t *src_layer_desc,
328 const dnnl_memory_desc_t *src_iter_desc,
329 const dnnl_memory_desc_t *src_iter_c_desc,
330 const dnnl_memory_desc_t *weights_layer_desc,
331 const dnnl_memory_desc_t *weights_iter_desc,
332 const dnnl_memory_desc_t *weights_peephole_desc,
333 const dnnl_memory_desc_t *weights_projection_desc,
334 const dnnl_memory_desc_t *bias_desc,
335 const dnnl_memory_desc_t *dst_layer_desc,
336 const dnnl_memory_desc_t *dst_iter_desc,
337 const dnnl_memory_desc_t *dst_iter_c_desc,
338 const dnnl_memory_desc_t *diff_src_layer_desc,
339 const dnnl_memory_desc_t *diff_src_iter_desc,
340 const dnnl_memory_desc_t *diff_src_iter_c_desc,
341 const dnnl_memory_desc_t *diff_weights_layer_desc,
342 const dnnl_memory_desc_t *diff_weights_iter_desc,
343 const dnnl_memory_desc_t *diff_weights_peephole_desc,
344 const dnnl_memory_desc_t *diff_weights_projection_desc,
345 const dnnl_memory_desc_t *diff_bias_desc,
346 const dnnl_memory_desc_t *diff_dst_layer_desc,
347 const dnnl_memory_desc_t *diff_dst_iter_desc,
348 const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
349 dnnl_alg_kind_t activation = dnnl_alg_kind_undef, float alpha = 0.0f,
350 float beta = 0.0f) {
351
352 // check that a supported cell kind has been passed
353 bool args_ok = one_of(cell_kind, dnnl_vanilla_rnn, dnnl_vanilla_lstm,
354 dnnl_vanilla_gru, dnnl_lbr_gru);
355 if (!args_ok) return invalid_arguments;
356
357 // check that all mandatory parameters are non-null
358 args_ok = args_ok
359 && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
360 dst_layer_desc, diff_src_layer_desc,
361 diff_weights_layer_desc, diff_weights_iter_desc,
362 diff_dst_layer_desc);
363 if (!args_ok) return invalid_arguments;
364
365 if (cell_kind == dnnl_vanilla_rnn) {
366 using namespace alg_kind;
367 args_ok = args_ok
368 && one_of(activation, eltwise_relu, eltwise_tanh,
369 eltwise_logistic);
370 if (!args_ok) return invalid_arguments;
371 }
372
373 if (cell_kind == dnnl_vanilla_lstm) {
374 // check if optional *_iter is provided then *_iter_c is provided too
375 args_ok = args_ok && xnor_md(src_iter_desc, src_iter_c_desc)
376 && xnor_md(dst_iter_desc, dst_iter_c_desc);
377 if (!args_ok) return invalid_arguments;
378 }
379
380 // check if optional md is provided then diff_md is provided too
381 args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
382 && xnor_md(weights_peephole_desc, diff_weights_peephole_desc)
383 && xnor_md(weights_projection_desc, diff_weights_projection_desc)
384 && xnor_md(src_iter_desc, diff_src_iter_desc)
385 && xnor_md(src_iter_c_desc, diff_src_iter_c_desc)
386 && xnor_md(dst_iter_desc, diff_dst_iter_desc)
387 && xnor_md(dst_iter_c_desc, diff_dst_iter_c_desc);
388 if (!args_ok) return invalid_arguments;
389
390 CHECK(check_runtime_dims_or_strides({src_layer_desc, src_iter_desc,
391 src_iter_c_desc, weights_layer_desc, weights_iter_desc,
392 weights_peephole_desc, weights_projection_desc, bias_desc,
393 dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
394 diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
395 diff_weights_iter_desc, diff_weights_peephole_desc,
396 diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
397 diff_dst_iter_desc, diff_dst_iter_c_desc}));
398
399 auto rd = dnnl_rnn_desc_t();
400
401 rd.primitive_kind = primitive_kind::rnn;
402 rd.prop_kind = prop_kind;
403 rd.cell_kind = cell_kind;
404 rd.direction = direction;
405
406 maybe_init_md(rd.src_layer_desc, src_layer_desc);
407 maybe_init_md(rd.src_iter_desc, src_iter_desc);
408 maybe_init_md(rd.src_iter_c_desc, src_iter_c_desc);
409 maybe_init_md(rd.weights_layer_desc, weights_layer_desc);
410 maybe_init_md(rd.weights_iter_desc, weights_iter_desc);
411 maybe_init_md(rd.weights_peephole_desc, weights_peephole_desc);
412 maybe_init_md(rd.weights_projection_desc, weights_projection_desc);
413 maybe_init_md(rd.bias_desc, bias_desc);
414 maybe_init_md(rd.dst_layer_desc, dst_layer_desc);
415 maybe_init_md(rd.dst_iter_desc, dst_iter_desc);
416 maybe_init_md(rd.dst_iter_c_desc, dst_iter_c_desc);
417 maybe_init_md(rd.diff_src_layer_desc, diff_src_layer_desc);
418 maybe_init_md(rd.diff_src_iter_desc, diff_src_iter_desc);
419 maybe_init_md(rd.diff_src_iter_c_desc, diff_src_iter_c_desc);
420 maybe_init_md(rd.diff_weights_layer_desc, diff_weights_layer_desc);
421 maybe_init_md(rd.diff_weights_iter_desc, diff_weights_iter_desc);
422 maybe_init_md(rd.diff_weights_peephole_desc, diff_weights_peephole_desc);
423 maybe_init_md(
424 rd.diff_weights_projection_desc, diff_weights_projection_desc);
425 maybe_init_md(rd.diff_bias_desc, diff_bias_desc);
426 maybe_init_md(rd.diff_dst_layer_desc, diff_dst_layer_desc);
427 maybe_init_md(rd.diff_dst_iter_desc, diff_dst_iter_desc);
428 maybe_init_md(rd.diff_dst_iter_c_desc, diff_dst_iter_c_desc);
429
430 rd.flags = flags;
431 rd.activation_kind = activation;
432 rd.alpha = alpha;
433 rd.beta = beta;
434
435 CHECK(check_data_type_consistency_fwd(rd));
436 CHECK(check_data_type_consistency_bwd(rd));
437
438 CHECK(check_dim_consistency(rd));
439
440 *rnn_desc = rd;
441
442 return success;
443 }
444
445 } // namespace
446
447 /* Public C Api */
448
dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,const dnnl_alg_kind_t activation,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags,float alpha,float beta)449 status_t dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
450 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
451 const dnnl_rnn_direction_t direction,
452 const dnnl_memory_desc_t *src_layer_desc,
453 const dnnl_memory_desc_t *src_iter_desc,
454 const dnnl_memory_desc_t *weights_layer_desc,
455 const dnnl_memory_desc_t *weights_iter_desc,
456 const dnnl_memory_desc_t *bias_desc,
457 const dnnl_memory_desc_t *dst_layer_desc,
458 const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha,
459 float beta) {
460 status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
461 dnnl_vanilla_rnn, direction, src_layer_desc, src_iter_desc, nullptr,
462 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
463 dst_layer_desc, dst_iter_desc, nullptr, flags, activation, alpha,
464 beta);
465 return st;
466 }
467
dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)468 status_t dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
469 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
470 const dnnl_memory_desc_t *src_layer_desc,
471 const dnnl_memory_desc_t *src_iter_desc,
472 const dnnl_memory_desc_t *src_iter_c_desc,
473 const dnnl_memory_desc_t *weights_layer_desc,
474 const dnnl_memory_desc_t *weights_iter_desc,
475 const dnnl_memory_desc_t *bias_desc,
476 const dnnl_memory_desc_t *dst_layer_desc,
477 const dnnl_memory_desc_t *dst_iter_desc,
478 const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
479 return dnnl_lstm_forward_desc_init_v3(rnn_desc, prop_kind, direction,
480 src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
481 weights_iter_desc, nullptr, nullptr, bias_desc, dst_layer_desc,
482 dst_iter_desc, dst_iter_c_desc, flags);
483 }
484
dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)485 status_t dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
486 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
487 const dnnl_memory_desc_t *src_layer_desc,
488 const dnnl_memory_desc_t *src_iter_desc,
489 const dnnl_memory_desc_t *src_iter_c_desc,
490 const dnnl_memory_desc_t *weights_layer_desc,
491 const dnnl_memory_desc_t *weights_iter_desc,
492 const dnnl_memory_desc_t *weights_peephole_desc,
493 const dnnl_memory_desc_t *bias_desc,
494 const dnnl_memory_desc_t *dst_layer_desc,
495 const dnnl_memory_desc_t *dst_iter_desc,
496 const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
497 return dnnl_lstm_forward_desc_init_v3(rnn_desc, prop_kind, direction,
498 src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
499 weights_iter_desc, weights_peephole_desc, nullptr, bias_desc,
500 dst_layer_desc, dst_iter_desc, dst_iter_c_desc, flags);
501 }
502
dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,unsigned flags)503 status_t dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
504 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
505 const dnnl_memory_desc_t *src_layer_desc,
506 const dnnl_memory_desc_t *src_iter_desc,
507 const dnnl_memory_desc_t *src_iter_c_desc,
508 const dnnl_memory_desc_t *weights_layer_desc,
509 const dnnl_memory_desc_t *weights_iter_desc,
510 const dnnl_memory_desc_t *weights_peephole_desc,
511 const dnnl_memory_desc_t *weights_projection_desc,
512 const dnnl_memory_desc_t *bias_desc,
513 const dnnl_memory_desc_t *dst_layer_desc,
514 const dnnl_memory_desc_t *dst_iter_desc,
515 const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags) {
516 status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
517 dnnl_vanilla_lstm, direction, src_layer_desc, src_iter_desc,
518 src_iter_c_desc, weights_layer_desc, weights_iter_desc,
519 weights_peephole_desc, weights_projection_desc, bias_desc,
520 dst_layer_desc, dst_iter_desc, dst_iter_c_desc, flags);
521 return st;
522 }
523
dnnl_gru_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags)524 status_t dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
525 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
526 const dnnl_memory_desc_t *src_layer_desc,
527 const dnnl_memory_desc_t *src_iter_desc,
528 const dnnl_memory_desc_t *weights_layer_desc,
529 const dnnl_memory_desc_t *weights_iter_desc,
530 const dnnl_memory_desc_t *bias_desc,
531 const dnnl_memory_desc_t *dst_layer_desc,
532 const dnnl_memory_desc_t *dst_iter_desc, unsigned flags) {
533 status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind,
534 dnnl_vanilla_gru, direction, src_layer_desc, src_iter_desc, nullptr,
535 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
536 dst_layer_desc, dst_iter_desc, nullptr, flags);
537 return st;
538 }
539
dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,unsigned flags)540 status_t dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
541 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
542 const dnnl_memory_desc_t *src_layer_desc,
543 const dnnl_memory_desc_t *src_iter_desc,
544 const dnnl_memory_desc_t *weights_layer_desc,
545 const dnnl_memory_desc_t *weights_iter_desc,
546 const dnnl_memory_desc_t *bias_desc,
547 const dnnl_memory_desc_t *dst_layer_desc,
548 const dnnl_memory_desc_t *dst_iter_desc, unsigned flags) {
549 status_t st = rnn_common_fwd_desc_init(rnn_desc, prop_kind, dnnl_lbr_gru,
550 direction, src_layer_desc, src_iter_desc, nullptr,
551 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
552 dst_layer_desc, dst_iter_desc, nullptr, flags);
553 return st;
554 }
555
dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,const dnnl_alg_kind_t activation,const dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags,float alpha,float beta)556 status_t dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
557 dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
558 const dnnl_rnn_direction_t direction,
559 const dnnl_memory_desc_t *src_layer_desc,
560 const dnnl_memory_desc_t *src_iter_desc,
561 const dnnl_memory_desc_t *weights_layer_desc,
562 const dnnl_memory_desc_t *weights_iter_desc,
563 const dnnl_memory_desc_t *bias_desc,
564 const dnnl_memory_desc_t *dst_layer_desc,
565 const dnnl_memory_desc_t *dst_iter_desc,
566 const dnnl_memory_desc_t *diff_src_layer_desc,
567 const dnnl_memory_desc_t *diff_src_iter_desc,
568 const dnnl_memory_desc_t *diff_weights_layer_desc,
569 const dnnl_memory_desc_t *diff_weights_iter_desc,
570 const dnnl_memory_desc_t *diff_bias_desc,
571 const dnnl_memory_desc_t *diff_dst_layer_desc,
572 const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
573 float alpha, float beta) {
574 status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
575 dnnl_vanilla_rnn, direction, src_layer_desc, src_iter_desc, nullptr,
576 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
577 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
578 diff_src_iter_desc, nullptr, diff_weights_layer_desc,
579 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
580 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags, activation,
581 alpha, beta);
582 return st;
583 }
584
dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)585 status_t dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
586 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
587 const dnnl_memory_desc_t *src_layer_desc,
588 const dnnl_memory_desc_t *src_iter_desc,
589 const dnnl_memory_desc_t *src_iter_c_desc,
590 const dnnl_memory_desc_t *weights_layer_desc,
591 const dnnl_memory_desc_t *weights_iter_desc,
592 const dnnl_memory_desc_t *bias_desc,
593 const dnnl_memory_desc_t *dst_layer_desc,
594 const dnnl_memory_desc_t *dst_iter_desc,
595 const dnnl_memory_desc_t *dst_iter_c_desc,
596 const dnnl_memory_desc_t *diff_src_layer_desc,
597 const dnnl_memory_desc_t *diff_src_iter_desc,
598 const dnnl_memory_desc_t *diff_src_iter_c_desc,
599 const dnnl_memory_desc_t *diff_weights_layer_desc,
600 const dnnl_memory_desc_t *diff_weights_iter_desc,
601 const dnnl_memory_desc_t *diff_bias_desc,
602 const dnnl_memory_desc_t *diff_dst_layer_desc,
603 const dnnl_memory_desc_t *diff_dst_iter_desc,
604 const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
605 return dnnl_lstm_backward_desc_init_v3(rnn_desc, prop_kind, direction,
606 src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
607 weights_iter_desc, nullptr, nullptr, bias_desc, dst_layer_desc,
608 dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
609 diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
610 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
611 diff_dst_layer_desc, diff_dst_iter_desc, diff_dst_iter_c_desc,
612 flags);
613 }
614
dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)615 status_t dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
616 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
617 const dnnl_memory_desc_t *src_layer_desc,
618 const dnnl_memory_desc_t *src_iter_desc,
619 const dnnl_memory_desc_t *src_iter_c_desc,
620 const dnnl_memory_desc_t *weights_layer_desc,
621 const dnnl_memory_desc_t *weights_iter_desc,
622 const dnnl_memory_desc_t *weights_peephole_desc,
623 const dnnl_memory_desc_t *bias_desc,
624 const dnnl_memory_desc_t *dst_layer_desc,
625 const dnnl_memory_desc_t *dst_iter_desc,
626 const dnnl_memory_desc_t *dst_iter_c_desc,
627 const dnnl_memory_desc_t *diff_src_layer_desc,
628 const dnnl_memory_desc_t *diff_src_iter_desc,
629 const dnnl_memory_desc_t *diff_src_iter_c_desc,
630 const dnnl_memory_desc_t *diff_weights_layer_desc,
631 const dnnl_memory_desc_t *diff_weights_iter_desc,
632 const dnnl_memory_desc_t *diff_weights_peephole_desc,
633 const dnnl_memory_desc_t *diff_bias_desc,
634 const dnnl_memory_desc_t *diff_dst_layer_desc,
635 const dnnl_memory_desc_t *diff_dst_iter_desc,
636 const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
637 return dnnl_lstm_backward_desc_init_v3(rnn_desc, prop_kind, direction,
638 src_layer_desc, src_iter_desc, src_iter_c_desc, weights_layer_desc,
639 weights_iter_desc, weights_peephole_desc, nullptr, bias_desc,
640 dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
641 diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
642 diff_weights_iter_desc, diff_weights_peephole_desc, nullptr,
643 diff_bias_desc, diff_dst_layer_desc, diff_dst_iter_desc,
644 diff_dst_iter_c_desc, flags);
645 }
646
dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * src_iter_c_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * weights_peephole_desc,const dnnl_memory_desc_t * weights_projection_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * dst_iter_c_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_src_iter_c_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_weights_peephole_desc,const dnnl_memory_desc_t * diff_weights_projection_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,const dnnl_memory_desc_t * diff_dst_iter_c_desc,unsigned flags)647 status_t dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
648 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
649 const dnnl_memory_desc_t *src_layer_desc,
650 const dnnl_memory_desc_t *src_iter_desc,
651 const dnnl_memory_desc_t *src_iter_c_desc,
652 const dnnl_memory_desc_t *weights_layer_desc,
653 const dnnl_memory_desc_t *weights_iter_desc,
654 const dnnl_memory_desc_t *weights_peephole_desc,
655 const dnnl_memory_desc_t *weights_projection_desc,
656 const dnnl_memory_desc_t *bias_desc,
657 const dnnl_memory_desc_t *dst_layer_desc,
658 const dnnl_memory_desc_t *dst_iter_desc,
659 const dnnl_memory_desc_t *dst_iter_c_desc,
660 const dnnl_memory_desc_t *diff_src_layer_desc,
661 const dnnl_memory_desc_t *diff_src_iter_desc,
662 const dnnl_memory_desc_t *diff_src_iter_c_desc,
663 const dnnl_memory_desc_t *diff_weights_layer_desc,
664 const dnnl_memory_desc_t *diff_weights_iter_desc,
665 const dnnl_memory_desc_t *diff_weights_peephole_desc,
666 const dnnl_memory_desc_t *diff_weights_projection_desc,
667 const dnnl_memory_desc_t *diff_bias_desc,
668 const dnnl_memory_desc_t *diff_dst_layer_desc,
669 const dnnl_memory_desc_t *diff_dst_iter_desc,
670 const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags) {
671 status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
672 dnnl_vanilla_lstm, direction, src_layer_desc, src_iter_desc,
673 src_iter_c_desc, weights_layer_desc, weights_iter_desc,
674 weights_peephole_desc, weights_projection_desc, bias_desc,
675 dst_layer_desc, dst_iter_desc, dst_iter_c_desc, diff_src_layer_desc,
676 diff_src_iter_desc, diff_src_iter_c_desc, diff_weights_layer_desc,
677 diff_weights_iter_desc, diff_weights_peephole_desc,
678 diff_weights_projection_desc, diff_bias_desc, diff_dst_layer_desc,
679 diff_dst_iter_desc, diff_dst_iter_c_desc, flags);
680 return st;
681 }
682
dnnl_gru_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags)683 status_t dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
684 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
685 const dnnl_memory_desc_t *src_layer_desc,
686 const dnnl_memory_desc_t *src_iter_desc,
687 const dnnl_memory_desc_t *weights_layer_desc,
688 const dnnl_memory_desc_t *weights_iter_desc,
689 const dnnl_memory_desc_t *bias_desc,
690 const dnnl_memory_desc_t *dst_layer_desc,
691 const dnnl_memory_desc_t *dst_iter_desc,
692 const dnnl_memory_desc_t *diff_src_layer_desc,
693 const dnnl_memory_desc_t *diff_src_iter_desc,
694 const dnnl_memory_desc_t *diff_weights_layer_desc,
695 const dnnl_memory_desc_t *diff_weights_iter_desc,
696 const dnnl_memory_desc_t *diff_bias_desc,
697 const dnnl_memory_desc_t *diff_dst_layer_desc,
698 const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags) {
699 status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind,
700 dnnl_vanilla_gru, direction, src_layer_desc, src_iter_desc, nullptr,
701 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
702 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
703 diff_src_iter_desc, nullptr, diff_weights_layer_desc,
704 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
705 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags);
706 return st;
707 }
708
dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t * rnn_desc,dnnl_prop_kind_t prop_kind,dnnl_rnn_direction_t direction,const dnnl_memory_desc_t * src_layer_desc,const dnnl_memory_desc_t * src_iter_desc,const dnnl_memory_desc_t * weights_layer_desc,const dnnl_memory_desc_t * weights_iter_desc,const dnnl_memory_desc_t * bias_desc,const dnnl_memory_desc_t * dst_layer_desc,const dnnl_memory_desc_t * dst_iter_desc,const dnnl_memory_desc_t * diff_src_layer_desc,const dnnl_memory_desc_t * diff_src_iter_desc,const dnnl_memory_desc_t * diff_weights_layer_desc,const dnnl_memory_desc_t * diff_weights_iter_desc,const dnnl_memory_desc_t * diff_bias_desc,const dnnl_memory_desc_t * diff_dst_layer_desc,const dnnl_memory_desc_t * diff_dst_iter_desc,unsigned flags)709 status_t dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
710 dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
711 const dnnl_memory_desc_t *src_layer_desc,
712 const dnnl_memory_desc_t *src_iter_desc,
713 const dnnl_memory_desc_t *weights_layer_desc,
714 const dnnl_memory_desc_t *weights_iter_desc,
715 const dnnl_memory_desc_t *bias_desc,
716 const dnnl_memory_desc_t *dst_layer_desc,
717 const dnnl_memory_desc_t *dst_iter_desc,
718 const dnnl_memory_desc_t *diff_src_layer_desc,
719 const dnnl_memory_desc_t *diff_src_iter_desc,
720 const dnnl_memory_desc_t *diff_weights_layer_desc,
721 const dnnl_memory_desc_t *diff_weights_iter_desc,
722 const dnnl_memory_desc_t *diff_bias_desc,
723 const dnnl_memory_desc_t *diff_dst_layer_desc,
724 const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags) {
725 status_t st = rnn_common_bwd_desc_init(rnn_desc, prop_kind, dnnl_lbr_gru,
726 direction, src_layer_desc, src_iter_desc, nullptr,
727 weights_layer_desc, weights_iter_desc, nullptr, nullptr, bias_desc,
728 dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
729 diff_src_iter_desc, nullptr, diff_weights_layer_desc,
730 diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
731 diff_dst_layer_desc, diff_dst_iter_desc, nullptr, flags);
732 return st;
733 }
734