1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 /// @file
18 /// C API
19 
20 #ifndef ONEAPI_DNNL_DNNL_H
21 #define ONEAPI_DNNL_DNNL_H
22 
23 #include "oneapi/dnnl/dnnl_config.h"
24 #include "oneapi/dnnl/dnnl_types.h"
25 #include "oneapi/dnnl/dnnl_version.h"
26 
27 #ifdef __cplusplus
28 extern "C" {
29 #endif
30 
31 /// @addtogroup dnnl_api
32 /// @{
33 
34 /// @addtogroup dnnl_api_primitives
35 /// @{
36 
37 /// @addtogroup dnnl_api_primitives_common
38 /// @{
39 
40 /// Creates a primitive descriptor iterator.
41 ///
42 /// @param iterator Output primitive descriptor iterator.
43 /// @param op_desc Operation descriptor.
44 /// @param attr Primitive attributes (can be NULL).
45 /// @param engine Engine to use.
46 /// @param hint_forward_primitive_desc For backward propagation: primitive
47 ///     descriptor for a respective forward propagation primitive. Pass NULL
48 ///     for forward propagation.
49 /// @returns #dnnl_success on success and a status describing the error
50 ///     otherwise.
51 dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(
52         dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc,
53         const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
54         const_dnnl_primitive_desc_t hint_forward_primitive_desc);
55 
56 /// Advances the primitive descriptor iterator to point to the next available
57 /// implementation.
58 ///
59 /// @param iterator A primitive descriptor iterator to advance.
60 /// @returns #dnnl_success on success and a status describing the error
61 ///     otherwise.
/// @returns #dnnl_iterator_ends if no more implementations are available.
63 dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(
64         dnnl_primitive_desc_iterator_t iterator);
65 
66 /// Fetches the current primitive descriptor from a primitive descriptor
67 /// iterator.
68 ///
69 /// @note
70 ///     The user is responsible for deleting the resulting primitive
71 ///     descriptor using dnnl_primitive_desc_destroy().
72 ///
73 /// @param iterator A primitive descriptor iterator.
74 /// @returns A primitive descriptor.
75 dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(
76         const_dnnl_primitive_desc_iterator_t iterator);
77 
78 /// Destroys a primitive descriptor iterator.
79 ///
80 /// @param iterator Primitive descriptor iterator to destroy.
81 /// @returns #dnnl_success on success and a status describing the error
82 ///     otherwise.
83 dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(
84         dnnl_primitive_desc_iterator_t iterator);
85 
86 /// Creates a primitive descriptor. This function is equivalent to a sequence
87 /// of #dnnl_primitive_desc_iterator_create() and
88 /// #dnnl_primitive_desc_iterator_fetch(). In other words, the library will
89 /// pick the first suitable implementation.
90 ///
91 /// @param primitive_desc Output primitive descriptor.
92 /// @param op_desc Operation descriptor.
93 /// @param attr Primitive attributes (can be NULL).
94 /// @param engine Engine to use.
95 /// @param hint_forward_primitive_desc For backward propagation: primitive
96 ///     descriptor for a respective forward propagation primitive. Pass NULL
97 ///     for forward propagation.
98 /// @returns #dnnl_success on success and a status describing the error
99 ///     otherwise.
100 dnnl_status_t DNNL_API dnnl_primitive_desc_create(
101         dnnl_primitive_desc_t *primitive_desc, const_dnnl_op_desc_t op_desc,
102         const_dnnl_primitive_attr_t attr, dnnl_engine_t engine,
103         const_dnnl_primitive_desc_t hint_forward_primitive_desc);
104 
105 /// Clones a primitive descriptor. The resulting primitive descriptor must be
106 /// destroyed separately.
107 ///
108 /// @param primitive_desc Output primitive descriptor.
109 /// @param existing_primitive_desc Primitive descriptor to clone.
110 /// @returns #dnnl_success on success and a status describing the error
111 ///     otherwise.
112 dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
113         dnnl_primitive_desc_t *primitive_desc,
114         const_dnnl_primitive_desc_t existing_primitive_desc);
115 
116 /// Returns a constant reference to the attributes of a primitive descriptor.
117 ///
118 /// @warning
119 ///     It is an error to destroy the resulting @p attr.
120 ///
121 /// @warning
122 ///     The lifetime of an @p attr is the same as that of a @p
123 ///     primitive_desc, so it is an error to use the @p attr once the @p
124 ///     primitive_desc has been destroyed.
125 ///
126 /// @param primitive_desc Primitive descriptor.
127 /// @param attr Output primitive attributes.
128 /// @returns #dnnl_success on success and a status describing the error
129 ///     otherwise.
130 dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
131         const_dnnl_primitive_desc_t primitive_desc,
132         const_dnnl_primitive_attr_t *attr);
133 
134 /// Destroys a primitive descriptor.
135 ///
136 /// @param primitive_desc Primitive descriptor to destroy.
137 /// @returns #dnnl_success on success and a status describing the error
138 ///     otherwise.
139 dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
140         dnnl_primitive_desc_t primitive_desc);
141 
142 /// Queries a primitive descriptor for various pieces of information.
143 ///
144 /// The most common use case is to query a primitive descriptor, created with
145 /// source, weights, and destination memory descriptors with format tags set
146 /// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
147 /// case the @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
148 /// #dnnl_query_dst_md respectively) so that it is possible to create memory
149 /// objects and reorder primitives if necessary.
150 ///
151 /// Another typical use case is to query a primitive descriptor for workspace
152 /// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
153 /// query returns #dnnl_not_required status, then workspace memory is not
154 /// required.
155 ///
156 /// @note
157 ///     When querying for a memory descriptor for a scratchpad, a workspace,
158 ///     or an optional parameter, the query will return a pointer to a zero
159 ///     memory descriptor if the parameter is not needed.
160 ///
161 /// A few other use cases:
162 ///  - query a primitive descriptor for the underlying operation descriptor
163 ///    (#dnnl_query_convolution_d, #dnnl_query_eltwise_d, #dnnl_query_rnn_d,
164 ///    etc.)
165 ///  - query a primitive descriptor for the implementation information string
166 ///    (#dnnl_query_impl_info_str)
167 ///  - query a primitive descriptor for the number of inputs and outputs
168 ///    (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
169 ///    respectively)
170 ///
171 /// @sa dnnl_query_t for more options
172 ///
173 /// @param primitive_desc Primitive descriptor.
174 /// @param what Parameter to query.
175 /// @param index Index of the parameter to query for.
176 /// @param result Output result. The type depends on the query. For example,
177 ///     it must be a @c dnnl_memory_desc_t* if querying for a memory
178 ///     descriptor.
179 /// @returns #dnnl_success on success and a status describing the error
180 ///     otherwise.
181 dnnl_status_t DNNL_API dnnl_primitive_desc_query(
182         const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
183         int index, void *result);
184 
/// Queries primitive descriptor for a memory descriptor.
///
/// @note
///     This function is a convenience version of
///     #dnnl_primitive_desc_query().
///
/// @param primitive_desc Primitive descriptor.
/// @param what Kind of memory descriptor parameter to query for.
/// @param index Index of the parameter to query.
/// @returns A pointer to the requested memory descriptor.
/// @returns A pointer to a zero memory descriptor if the parameter is not
///     needed.
/// @returns NULL in case of any error.
const dnnl_memory_desc_t DNNL_API *dnnl_primitive_desc_query_md(
        const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
        int index);
202 
/// Queries primitive descriptor for a signed 32-bit integer.
204 ///
205 /// @note
206 ///     This function is a convenience version of
207 ///     #dnnl_primitive_desc_query().
208 ///
209 /// @param primitive_desc Primitive descriptor.
210 /// @param what Kind of the value to query for.
211 /// @param index Index of the parameter to query.
212 /// @returns The requested value.
213 /// @returns 0 in case of any error (in particular if the queried entity is
214 ///     not of type int32_t). Note that 0 may also be the actual returned
215 ///     value.
216 int DNNL_API dnnl_primitive_desc_query_s32(
217         const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
218         int index);
219 
220 /// Creates a primitive.
221 ///
222 /// @param primitive Output primitive.
223 /// @param primitive_desc Primitive descriptor used to create the primitive.
224 /// @returns #dnnl_success on success and a status describing the error
225 ///     otherwise.
226 dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
227         const_dnnl_primitive_desc_t primitive_desc);
228 
/// Executes a primitive.
///
/// @note
///     If any argument in @p args is padded (padded_dims > dims), the
///     primitive execution will assume properly zero-padded input arguments,
///     and produce zero-padded output arguments.
///
/// @param primitive Primitive to execute.
/// @param stream Stream to use.
/// @param nargs Number of arguments.
/// @param args Array of arguments. Each argument is an
///     <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
///     values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
///     #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
///     descriptor as that returned by
///     #dnnl_primitive_desc_query_md(#dnnl_query_exec_arg_md, index).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
        dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
248 
249 /// Retrieves a constant reference to the primitive descriptor of a given
250 /// primitive.
251 ///
252 /// @warning
253 ///     It is an error to destroy the returned object. It is owned by the
254 ///     primitive. The @c const qualifier of the returned object prevents
255 ///     such attempts.
256 ///
257 /// @param primitive Primitive to query for the primitive descriptor.
258 /// @param primitive_desc Output primitive descriptor.
259 /// @returns #dnnl_success on success and a status describing the error
260 ///     otherwise.
261 dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
262         const_dnnl_primitive_t primitive,
263         const_dnnl_primitive_desc_t *primitive_desc);
264 
265 /// Destroys a primitive.
266 ///
267 /// @param primitive The primitive to destroy.
268 /// @returns #dnnl_success on success and a status describing the error
269 ///     otherwise.
270 dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
271 
272 /// @} dnnl_api_primitives_common
273 
274 /// @addtogroup dnnl_api_attributes
275 /// @{
276 
277 /// Creates an empty (default) primitive attributes with all the parameters
278 /// set to their default values.
279 ///
280 /// Empty attributes are implied whenever the respective argument is NULL.
281 ///
282 /// @param attr Output primitive attributes.
283 /// @returns #dnnl_success on success and a status describing the error
284 ///     otherwise.
285 dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);
286 
287 /// Clones primitive attributes.
288 ///
289 /// @param attr Output primitive attributes.
290 /// @param existing_attr Primitive attributes to clone.
291 /// @returns #dnnl_success on success and a status describing the error
292 ///     otherwise.
293 dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
294         dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);
295 
296 /// Destroys primitive attributes.
297 ///
298 /// @param attr Primitive attributes to destroy.
299 /// @returns #dnnl_success on success and a status describing the error
300 ///     otherwise.
301 dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
302 
303 /// Returns the primitive attributes scratchpad mode.
304 ///
305 /// @param attr Primitive attributes.
306 /// @param mode Output scratchpad mode.
307 /// @returns #dnnl_success on success and a status describing the error
308 ///     otherwise.
309 dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
310         const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);
311 
312 /// Sets primitive attributes scratchpad mode.
313 ///
314 /// @param attr Primitive attributes.
315 /// @param mode Scratchpad mode. The possible values are:
316 ///     #dnnl_scratchpad_mode_library (default) and
317 ///     #dnnl_scratchpad_mode_user.
318 /// @returns #dnnl_success on success and a status describing the error
319 ///     otherwise.
320 dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
321         dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
322 
323 /// Returns primitive attributes output scaling factors correspondence mask
324 /// and values.
325 ///
326 /// @warning
327 ///     The @p scales array is an internal part of the primitive attributes
328 ///     @p attr, so it is an error to modify or destroy the @p scales array.
329 ///
330 /// @warning
331 ///     The lifetime of @p scales array is the same as that of the primitive
332 ///     attributes @p attr to which it belongs, so it is an error to use
333 ///     @p scales after @p attr is destroyed.
334 ///
335 /// @param attr Primitive attributes.
336 /// @param count Output length of the array of scaling factors @p scales.
337 /// @param mask Output scaling factors correspondence mask that defines the
338 ///     correspondence between the output tensor dimensions and the @p scales
339 ///     vector. The set i-th bit indicates that a dedicated output scaling
340 ///     factor is used for each index along that dimension. The mask value of
341 ///     0 implies a common output scaling factor for the whole output tensor.
342 /// @param scales Output pointer to a constant array of scaling factors.
343 /// @returns #dnnl_success on success and a status describing the error
344 ///     otherwise.
345 dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(
346         const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
347         const float **scales);
348 
349 /// Sets output scaling factors correspondence mask and values.
350 ///
351 /// @note
352 ///     The order of dimensions does not depend on how elements are laid
353 ///     out in memory. For example:
354 ///     - for a 2D CNN activations tensor the order is always (n, c)
355 ///     - for a 4D CNN activations tensor the order is always (n, c, h, w)
356 ///     - for a 5D CNN weights tensor the order is always
357 ///        (g, oc, ic, kh, kw)
358 ///
359 /// Example usage:
360 /// @code
361 ///     int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
362 ///     float scales[oc] = { ... }; // unique output scales per output channel
363 ///     int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
364 ///
365 ///     dnnl_convolution_desc_t conv_d; // create a convolution descriptor
366 ///
367 ///     dnnl_primitive_attr_t attr;
368 ///     dnnl_primitive_attr_create(&attr); // create primitive attributes
369 ///     dnnl_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
370 ///
371 ///     dnnl_primitive_desc_t conv_pd;
372 ///     dnnl_primitive_desc_create(&conv_pd, &conv_d, attr, engine, NULL);
373 /// @endcode
374 ///
375 /// @param attr Primitive attributes.
376 /// @param count Length of the array of scaling factors @p scales.
377 /// @param mask Scaling factors correspondence mask that defines the
378 ///     correspondence between the output tensor dimensions and the @p scales
379 ///     array. The set i-th bit indicates that a dedicated output scaling
380 ///     factor is used for each index along that dimension. The mask value of
381 ///     0 implies a common output scaling factor for the whole output tensor.
382 /// @param scales Array of output scaling factors. If the output scaling
383 ///     factors are known at the time of this call, this array must contain @p
384 ///     count values and the following equality must hold:
385 ///     \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
386 ///     Violations can only be detected when the attributes are used to create
387 ///     a primitive descriptor.
388 ///     If the output scaling factors are not known at the time of the call,
389 ///     this array must contain a single #DNNL_RUNTIME_F32_VAL value and the
390 ///     output scaling factors must be passed at execution time as an argument
391 ///     with index #DNNL_ARG_ATTR_OUTPUT_SCALES.
392 /// @returns #dnnl_success on success and a status describing the error
393 ///     otherwise.
394 dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(
395         dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
396         const float *scales);
397 
/// Returns primitive attributes scaling factors correspondence mask and values
/// for a given memory argument.
///
/// @warning
///     The output @p scales array is an internal part of the primitive
///     attributes @p attr, so it is an error to modify or destroy the @p
///     scales array.
///
/// @warning
///     The lifetime of the @p scales array is the same as that of the primitive
///     attributes @p attr to which it belongs, so it is an error to use @p
///     scales after @p attr is destroyed.
///
/// @note
///     NOTE(review): unlike the sibling getters in this file (e.g.
///     dnnl_primitive_attr_get_output_scales() and
///     dnnl_primitive_attr_get_zero_points(), which take
///     const_dnnl_primitive_attr_t), @p attr is a non-const handle here even
///     though this call is a query — confirm whether this is intentional
///     before changing the published signature.
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
///     dnnl_primitive_execute() call.
/// @param count Output length of the array of scaling factors @p scales.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Output pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(
        dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
        const float **scales);
427 
428 /// Sets primitive attributes scaling factors for primitive operations for a
429 /// given memory argument.
430 ///
431 /// @sa dnnl_primitive_attr_set_output_scales
432 ///
433 ///
434 /// @param attr Primitive attributes.
435 /// @param arg Parameter argument index as passed to the
436 ///     dnnl_primitive_execute() call.
437 /// @param count Length of the array of scaling factors @p scales.
438 /// @param mask Scaling factors correspondence mask that defines the
439 ///     correspondence between the tensor dimensions and the @p scales array.
440 ///     The set i-th bit indicates that a dedicated scaling factor is used for
441 ///     each index along that dimension. Set the mask to 0 to use a common
442 ///     scaling factor for the whole output tensor.
443 /// @param scales Constant array of float scaling factors. This array must
444 ///     contain @p count scales and the following equality must hold:
445 ///     \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
446 /// @returns #dnnl_success on success and a status describing the error
447 ///     otherwise.
448 dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
449         dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
450         const float *scales);
451 
452 /// Returns @p count, correspondence zero point @p mask, and a pointer to a
453 /// constant int32_t array of @p zero_points for given @p attr and memory
454 /// argument (index), previously set by dnnl_primitive_attr_set_zero_points.
455 ///
456 /// @warning
457 ///     The output @p zero_points array is an internal part of the primitive
458 ///     attributes @p attr, so it is an error to modify or destroy the @p
459 ///     zero_points array.
460 ///
461 /// @warning
462 ///     The lifetime of @p zero_points array is the same as that of the
463 ///     primitive attributes @p attr to which it belongs, so it is an error
464 ///     to use @p zero_points after @p attr is destroyed.
465 ///
466 ///
467 /// @param attr Primitive attributes.
468 /// @param arg Parameter argument index as passed to the
469 ///     dnnl_primitive_execute() call.
470 /// @param count Output length of the array of zero points @p zero_points.
471 /// @param mask Output zero points correspondence mask that defines the
472 ///     correspondence between the output tensor dimensions and the @p
473 ///     zero_points array. The set i-th bit indicates that a dedicated output
474 ///     zero point is used for each index along that dimension. The mask
475 ///     value of 0 implies a common zero point for the whole output tensor.
476 /// @param zero_points Output pointer to a constant array of int32_t zero
477 ///     points.
478 /// @returns #dnnl_success on success and a status describing the error
479 ///     otherwise.
480 dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(
481         const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask,
482         const int32_t **zero_points);
483 
484 /// Sets primitive attributes zero points for primitive operations for a given
485 /// memory argument.
486 ///
487 /// @sa dnnl_primitive_attr_set_output_scales
488 ///
489 ///
490 /// @param attr Primitive attributes.
491 /// @param arg Parameter argument index as passed to the
492 ///     dnnl_primitive_execute() call.
493 /// @param count Length of the array of zero points @p zero_points.
494 /// @param mask Zero point correspondence mask that defines the
495 ///     correspondence between the tensor dimensions and the @p
496 ///     zero_points array. The set i-th bit indicates that a dedicated
497 ///     zero point is used for each index along that dimension. Set the
498 ///     mask to 0 to use a common zero point for the whole output tensor.
499 /// @param zero_points Constant array of int32_t zero points. If the zero
500 ///     points are known at the time of this call, this array must contain @p
501 ///     count zero points and the following equality must hold:
502 ///     \f[count = \prod\limits_{d \in mask} output.dims[d].\f]
503 ///     If the zero points are not known at the time of the call, this array
504 ///     must contain a single #DNNL_RUNTIME_S32_VAL and the zero points must
505 ///     be passed at execution time as an argument with index
506 ///     #DNNL_ARG_ATTR_ZERO_POINTS.
507 /// @returns #dnnl_success on success and a status describing the error
508 ///     otherwise.
509 dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
510         dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask,
511         const int32_t *zero_points);
512 
513 /// Returns primitive attributes post-ops.
514 ///
515 /// @warning
516 ///     The output @p post_ops points to the internal @p attr field, so it is
517 ///     an error to modify or destroy them. The lifetime of @p post_ops is
518 ///     the same as that of the @p attr it belongs to, so it is an error to
519 ///     use @p post_ops after @p attr has been destroyed.
520 ///
521 /// @param attr Primitive attributes.
522 /// @param post_ops Output post-ops.
523 /// @returns #dnnl_success on success and a status describing the error
524 ///     otherwise.
525 dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
526         const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);
527 
528 /// Sets primitive attributes post-ops.
529 ///
530 /// @note
531 ///     There is no way to check whether the post-ops would be supported by
532 ///     the target primitive. Any error will be reported by the
533 ///     dnnl_primitive_desc_create() function call.
534 ///
535 /// @param attr Primitive attributes.
536 /// @param post_ops Post-ops to set.
537 /// @returns #dnnl_success on success and a status describing the error
538 ///     otherwise.
539 dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
540         dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);
541 
542 /// Creates empty post-ops sequence.
543 ///
544 /// @param post_ops Output post-ops.
545 /// @returns #dnnl_success on success and a status describing the error
546 ///     otherwise.
547 dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);
548 
549 /// Destroys post-ops.
550 ///
551 /// @param post_ops Post-ops to destroy.
552 /// @returns #dnnl_success on success and a status describing the error
553 ///     otherwise.
554 dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
555 
556 /// Returns the length of post-ops.
557 ///
558 /// @param post_ops Post-ops.
559 /// @returns The number of post-ops entries.
560 int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);
561 
562 /// Returns the kind of a post-op entry.
563 ///
564 /// @param post_ops Post-ops.
565 /// @param index Post-op entry index.
566 /// @returns The kind of the post-op with the specified index.
567 /// @returns #dnnl_undefined_primitive if there is no post-op at the specified
568 ///     index.
569 dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
570         const_dnnl_post_ops_t post_ops, int index);
571 
572 /// Appends an accumulation (sum) to post-ops. Prior to accumulating the
573 /// result, the previous value is multiplied by a scale.
574 ///
575 /// The kind of this post-op is #dnnl_sum.
576 ///
577 /// This feature may improve performance for cases like residual learning
578 /// blocks, where the result of convolution is accumulated to the previously
579 /// computed activations. The parameter @p scale may be used for the
580 /// integer-based computations when the result and previous activations have
581 /// different logical scaling factors.
582 ///
583 /// In the simplest case when the accumulation is the only post-op, the
584 /// computations would be:
585 ///
586 ///     dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
587 ///
588 /// @note
589 ///     This post-op executes in-place and does not change the
590 ///     destination layout.
591 ///
592 /// @param post_ops Post-ops.
593 /// @param scale Accumulation scaling factor.
594 /// @returns #dnnl_success on success and a status describing the error
595 ///     otherwise.
596 dnnl_status_t DNNL_API dnnl_post_ops_append_sum(
597         dnnl_post_ops_t post_ops, float scale);
598 
599 /// Appends an accumulation v2 (sum) to post-ops. Prior to accumulating the
600 /// result, the previous value is multiplied by a scale.
601 ///
602 /// The kind of this post-op is #dnnl_sum.
603 ///
604 /// This feature may improve performance for cases like residual learning
605 /// blocks, where the result of convolution is accumulated to the previously
606 /// computed activations. The parameter @p scale may be used for the
607 /// integer-based computations when the result and previous activations have
608 /// different logical scaling factors.
609 ///
610 /// In the simplest case when the accumulation is the only post-op, the
611 /// computations would be:
612 ///
613 ///     dst[:] <- scale * dst[:] + op(...) // instead of dst[:] <- op(...)
614 ///
/// If @p data_type is specified, the original dst tensor will be reinterpreted
/// as a tensor with the provided data type. Because this is a
/// reinterpretation, @p data_type and the dst data type must have the same
/// size.
618 /// As a result, computations would be:
619 ///
620 ///     dst[:] <- scale * as_data_type(dst[:]) + op(...)
621 ///                                        // instead of dst[:] <- op(...)
622 /// @note
623 ///     This post-op executes in-place and does not change the
624 ///     destination layout.
625 ///
626 /// @param post_ops Post-ops.
627 /// @param scale Accumulation scaling factor.
628 /// @param data_type Accumulation data_type.
629 /// @returns #dnnl_success on success and a status describing the error
630 ///     otherwise.
631 dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(
632         dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type);
633 
634 /// Returns the parameters of an accumulation (sum) post-op.
635 ///
636 /// @param post_ops Post-ops.
637 /// @param index Index of the sum post-op.
638 /// @param scale Output accumulation scaling factor.
639 /// @returns #dnnl_success on success and a status describing the error
640 ///     otherwise.
641 /// @returns #dnnl_invalid_arguments if @p index does not refer to a sum
642 ///     post-op.
643 dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
644         const_dnnl_post_ops_t post_ops, int index, float *scale);
645 
646 /// Returns the parameters of an accumulation (sum) post-op with
647 /// a data type parameter.
648 ///
649 /// @param post_ops Post-ops.
650 /// @param index Index of the sum post-op.
651 /// @param scale Output accumulation scaling factor.
652 /// @param data_type Data type for accumulation.
653 /// @returns #dnnl_success on success and a status describing the error
654 ///     otherwise.
655 dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(
656         const_dnnl_post_ops_t post_ops, int index, float *scale,
657         dnnl_data_type_t *data_type);
658 
659 /// Appends an elementwise post-op.
660 ///
661 /// The kind of this post operation is #dnnl_eltwise.
662 ///
663 /// In the simplest case when the elementwise is the only post operation, the
664 /// computations would be:
665 ///
666 ///     dst[:] <- scale * eltwise_op (op(...)) // instead of dst[:] <- op(...)
667 ///
668 /// where eltwise_op is configured with the given parameters.
669 ///
670 /// @param post_ops Post-ops.
671 /// @param scale Scaling factor.
672 /// @param alg_kind Elementwise algorithm for the post-op.
673 /// @param alpha Alpha parameter for the elementwise algorithm.
674 /// @param beta Beta parameter for the elementwise algorithm.
675 /// @returns #dnnl_success on success and a status describing the error
676 ///     otherwise.
677 dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
678         float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta);
679 
680 /// Returns the parameters of an elementwise post-op.
681 ///
682 /// @param post_ops Post-ops.
683 /// @param index Index of the elementwise post-op.
684 /// @param scale Output scaling factor.
685 /// @param alg_kind Output elementwise algorithm kind.
686 /// @param alpha Output alpha parameter for the elementwise algorithm.
687 /// @param beta Output beta parameter for the elementwise algorithm.
688 /// @returns #dnnl_success on success and a status describing the error
689 ///     otherwise.
690 /// @returns #dnnl_invalid_arguments if @p index does not refer to an
691 ///     elementwise post-op.
692 dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
693         const_dnnl_post_ops_t post_ops, int index, float *scale,
694         dnnl_alg_kind_t *alg_kind, float *alpha, float *beta);
695 
696 /// Appends a depthwise post-op convolution with stride 1.
697 ///
698 /// This post-op can only be fused with a 2D 1x1 convolution (convolution with
699 /// weights spatial dimension equal to 1 i.e., kh=kw=1).
700 ///
701 /// The kind of this post-op is #dnnl_convolution.
702 ///
703 /// The number of outputs for primitive remain same as before. The output size
704 /// remain same as the original primitive due to stride=1.
705 ///
706 /// The Post-op can be defined as:
707 ///
708 ///      dst[:] <- scales * (conv_dw(conv_1x1))
709 ///
710 /// See @ref dev_guide_attributes_post_ops_depthwise and
711 /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
712 ///
713 /// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op.
/// @param bias_data_type Bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
726 dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops,
727         dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
728         dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
729         const float *scales);
730 
/// Returns the parameters of a depthwise post-op with stride 1.
732 ///
733 /// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
735 /// @param weights_data_type Weights data type of depthwise post-op
736 /// @param bias_data_type Bias data type of depthwise post-op
737 /// @param dst_data_type Output data type of depthwise post-op
738 /// @param count Output length of the array of scaling factors @p scales.
739 /// @param mask Output scaling factors correspondence mask that defines the
740 ///     correspondence between the output tensor dimensions and the @p
741 ///     scales array. The set i-th bit indicates that a dedicated output scaling
742 ///     factor is used for each index along that dimension. The mask value of 0
743 ///     implies a common scaling factor for the whole output tensor.
744 /// @param scales Output pointer to a constant array of float scaling factors.
745 /// @returns #dnnl_success on success and a status describing the error
///     otherwise.
747 dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(
748         const_dnnl_post_ops_t post_ops, int index,
749         dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
750         dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
751         const float **scales);
752 
753 /// Appends a depthwise post-op convolution with stride 2.
754 ///
755 /// This post-op can only be fused with a 2D 1x1 convolution (convolution with
756 /// weights spatial dimension equal to 1 i.e., kh=kw=1).
757 ///
758 /// The kind of this post-op is #dnnl_convolution.
759 ///
760 /// The number of outputs for primitive remain same as before. The output
761 /// spatial size can be derived as below:
762 ///
763 /// output_height = ceil(output_height_1x1_convolution, stride)
764 /// output_width = ceil(output_width_1x1_convolution, stride)
765 ///
766 /// The Post-op can be defined as:
767 ///
768 ///      dst[:] <- scales * (conv_dw(conv_1x1))
769 ///
770 /// See @ref dev_guide_attributes_post_ops_depthwise and
771 /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
772 ///
773 /// @param post_ops Post-ops.
/// @param weights_data_type Weights data type of depthwise post-op.
/// @param bias_data_type Bias data type of depthwise post-op.
/// @param dst_data_type Output data type of depthwise post-op.
/// @param count Length of the array of scaling factors @p scales.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales array. The set i-th bit indicates that a dedicated output scaling
///     factor is used for each index along that dimension. The mask value of 0
///     implies a common scaling factor for the whole output tensor.
/// @param scales Pointer to a constant array of float scaling factors.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
786 dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops,
787         dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
788         dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask,
789         const float *scales);
790 
/// Returns the parameters of a depthwise post-op with stride 2.
792 ///
793 /// @param post_ops Post-ops.
/// @param index Index of the depthwise post-op.
795 /// @param weights_data_type Weights data type of depthwise post-op
796 /// @param bias_data_type Bias data type of depthwise post-op
797 /// @param dst_data_type Output data type of depthwise post-op
798 /// @param count Output length of the array of scaling factors @p scales.
799 /// @param mask Output scaling factors correspondence mask that defines the
800 ///     correspondence between the output tensor dimensions and the @p
801 ///     scales array. The set i-th bit indicates that a dedicated output scaling
802 ///     factor is used for each index along that dimension. The mask value of 0
803 ///     implies a common scaling factor for the whole output tensor.
804 /// @param scales Output pointer to a constant array of float scaling factors.
805 /// @returns #dnnl_success on success and a status describing the error
///     otherwise.
807 dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(
808         const_dnnl_post_ops_t post_ops, int index,
809         dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
810         dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask,
811         const float **scales);
812 
813 /// Appends a binary post-op.
814 ///
815 /// The kind of this post operation is #dnnl_binary.
816 ///
817 /// In the simplest case when the binary is the only post operation, the
818 /// computations would be:
819 ///
820 ///     dst[:] <- binary_op (dst[:], another_input[:])
821 ///
822 /// where binary_op is configured with the given parameters. binary_op supports
823 /// broadcast semantics for a second operand.
824 ///
825 /// @param post_ops Post-ops.
826 /// @param alg_kind Binary algorithm for the post-op.
827 /// @param src1_desc Memory descriptor of a second operand.
828 /// @returns #dnnl_success on success and a status describing the error
829 ///     otherwise.
830 dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
831         dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc);
832 
833 /// Returns the parameters of a binary post-op.
834 ///
835 /// @param post_ops Post-ops.
836 /// @param index Index of the binary post-op.
837 /// @param alg_kind Output binary algorithm kind.
838 /// @param src1_desc Output memory descriptor of a second operand.
839 /// @returns #dnnl_success on success and a status describing the error
840 ///     otherwise.
841 /// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
842 ///     post-op.
843 dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
844         const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
845         const dnnl_memory_desc_t **src1_desc);
846 
847 /// @} dnnl_api_attributes
848 
849 /// @} dnnl_api_primitives
850 
851 /// @addtogroup dnnl_api_memory
852 /// @{
853 
854 /// Initializes a memory descriptor using dimensions and strides.
855 ///
856 /// @note
857 ///     As always, the logical order of dimensions corresponds to the `abc...`
858 ///     format tag, and the physical meaning of the dimensions depends on both
859 ///     the primitive that consumes the memory and the context of that
860 ///     consumption.
861 ///
862 /// @param memory_desc Output memory descriptor.
863 /// @param ndims Number of dimensions
864 /// @param dims Array of dimensions.
865 /// @param data_type Elements data type.
866 /// @param strides Strides in each dimension.
867 /// @returns #dnnl_success on success and a status describing the error
868 ///     otherwise.
869 dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(
870         dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
871         dnnl_data_type_t data_type, const dnnl_dims_t strides);
872 
873 /// Initializes a memory descriptor using dimensions and memory format tag.
874 ///
875 /// @note
876 ///     As always, the logical order of dimensions corresponds to the `abc...`
877 ///     format tag, and the physical meaning of the dimensions depends on both
878 ///     the primitive that consumes the memory and the context of that
879 ///     consumption.
880 ///
881 /// @param memory_desc Output memory descriptor.
882 /// @param ndims Number of dimensions
883 /// @param dims Array of dimensions.
884 /// @param data_type Elements data type.
885 /// @param tag Memory format tag. Can be #dnnl_format_tag_any which would
///     allow a primitive to choose the final memory format. In this case the
887 ///     format_kind field of the memory descriptor would be set to
888 ///     #dnnl_format_kind_any.
889 /// @returns #dnnl_success on success and a status describing the error
890 ///     otherwise.
891 dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(
892         dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
893         dnnl_data_type_t data_type, dnnl_format_tag_t tag);
894 
895 /// Initializes a memory descriptor for a region inside an area
896 /// described by an existing memory descriptor.
897 ///
898 /// @warning
899 ///     Some combinations of physical memory layout and/or offsets or dims may
900 ///     result in a failure to create a submemory.
901 //
902 /// @param memory_desc Output memory descriptor.
903 /// @param parent_memory_desc An existing memory descriptor.
904 /// @param dims Sizes of the region.
905 /// @param offsets Offsets to the region from the encompassing
///     memory object in each dimension.
907 /// @returns #dnnl_success on success and a status describing the error
908 ///     otherwise.
909 dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(
910         dnnl_memory_desc_t *memory_desc,
911         const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims,
912         const dnnl_dims_t offsets);
913 
914 /// Initializes a memory descriptor by reshaping an existing one. The new
915 /// memory descriptor inherits the data type. This operation is valid only for
916 /// memory descriptors that have format_kind set to #dnnl_blocked or
917 /// #dnnl_format_kind_any.
918 ///
919 /// The operation ensures the transformation of the physical memory format
920 /// corresponds to the transformation of the logical dimensions. If such
921 /// transformation is impossible, the function returns #dnnl_invalid_arguments.
922 ///
923 /// The reshape operation can be described as a combination of the following
924 /// basic operations:
925 /// 1. Add a dimension of size `1`. This is always possible.
926 /// 2. Remove a dimension of size `1`. This is possible only if the dimension
927 ///    has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
928 /// 3. Split a dimension into multiple ones. This is possible only if the size
929 ///    of the dimension is exactly equal to the product of the split ones and
930 ///    the dimension does not have padding (i.e.
931 ///    `padded_dims[dim] = dims[dim]`).
932 /// 4. Joining multiple consecutive dimensions into a single one. As in the
933 ///    cases above, this requires that the dimensions do not have padding and
934 ///    that the memory format is such that in physical memory these dimensions
935 ///    are dense and have the same order as their logical counterparts. This
936 ///    also assumes that these dimensions are not blocked.
937 ///    - Here, dense means:
938 ///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
939 ///    - And same order means:
940 ///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
941 ///
942 /// @warning
943 ///     Some combinations of physical memory layout and/or offsets or
944 ///     dimensions may result in a failure to make a reshape.
945 ///
946 /// @param out_memory_desc Output memory descriptor.
947 /// @param in_memory_desc An existing memory descriptor. Must have format_kind
948 ///     set to #dnnl_blocked or #dnnl_format_kind_any.
949 /// @param ndims Number of dimensions for the output memory descriptor.
950 /// @param dims Dimensions for the output memory descriptor.
951 /// @returns #dnnl_success on success and a status describing the error
952 ///     otherwise.
953 dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
954         dnnl_memory_desc_t *out_memory_desc,
955         const dnnl_memory_desc_t *in_memory_desc, int ndims,
956         const dnnl_dims_t dims);
957 
958 /// Initializes a memory descriptor by permuting axes in an existing one.
959 ///
960 /// The physical memory layout representation is adjusted accordingly to
961 /// maintain the consistency between the logical and physical parts of the
962 /// memory descriptor.
963 ///
964 /// The new memory descriptor inherits the data type. This operation is valid
965 /// only for memory descriptors that have format_kind set to #dnnl_blocked or
966 /// #dnnl_format_kind_any.
967 ///
968 /// The logical axes will be permuted in the following manner:
969 /// ```
970 /// for (i: 0 .. in_memory_desc->ndims)
971 ///     out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
972 /// ```
973 ///
974 /// Example:
975 /// @code
976 ///     dnnl_memory_desc_t in_md, out_md, expect_out_md;
977 ///
978 ///     const int permutation[] = {1, 0}; // swap the first and the second axes
979 ///
980 ///     dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
981 ///     dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
982 ///
983 ///     dnnl_memory_desc_init_by_tag(
984 ///             &in_md, 2, in_dims, data_type, in_tag);
985 ///     dnnl_memory_desc_init_by_tag(
986 ///             &expect_out_md, 2, out_dims, data_type, out_tag);
987 ///
988 ///     dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
989 ///     assert(dnnl_memory_desc_equal(&out_md, &expect_out_md));
990 /// @endcode
991 ///
992 /// @param out_memory_desc Output memory descriptor.
993 /// @param in_memory_desc An existing memory descriptor. Must have format_kind
994 ///     set to #dnnl_blocked or #dnnl_format_kind_any.
995 /// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
996 /// @returns #dnnl_success on success and a status describing the error
997 ///     otherwise.
998 dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
999         dnnl_memory_desc_t *out_memory_desc,
1000         const dnnl_memory_desc_t *in_memory_desc, const int *permutation);
1001 
1002 /// Compares two memory descriptors.
1003 ///
1004 /// Use this function to identify whether a reorder is required between the
/// two memories.
1006 ///
1007 /// @param lhs Left-hand side of the comparison.
1008 /// @param rhs Right-hand side of the comparison.
1009 /// @returns 1 if the descriptors are the same.
1010 /// @returns 0 if the descriptors are different.
1011 int DNNL_API dnnl_memory_desc_equal(
1012         const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs);
1013 
1014 /// Returns the size of a memory descriptor.
1015 ///
1016 /// @param memory_desc Memory descriptor.
1017 /// @returns The number of bytes required for memory described by a memory
1018 ///     descriptor.
1019 size_t DNNL_API dnnl_memory_desc_get_size(
1020         const dnnl_memory_desc_t *memory_desc);
1021 
1022 /// Returns the size of data type.
1023 ///
1024 /// @param data_type Data type.
1025 /// @returns The number of bytes occupied by data type.
1026 size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
1027 
1028 /// Creates a memory object.
1029 ///
1030 /// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
1031 /// object will have the underlying buffer set. In this case, the buffer will
1032 /// be initialized as if dnnl_memory_set_data_handle() had been called.
1033 ///
1034 /// @sa dnnl_memory_set_data_handle()
1035 ///
1036 /// @param memory Output memory object.
1037 /// @param memory_desc Memory descriptor.
1038 /// @param engine Engine to use.
1039 /// @param handle Handle of the memory buffer to use as an underlying storage.
1040 ///     - A pointer to the user-allocated buffer. In this case the library
1041 ///       doesn't own the buffer.
1042 ///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
1043 ///       allocate the buffer for the memory object. In this case the library
1044 ///       owns the buffer.
1045 ///     - DNNL_MEMORY_NONE to create dnnl_memory without an underlying buffer.
1046 /// @returns #dnnl_success on success and a status describing the error
1047 ///     otherwise.
1048 dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
1049         const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine,
1050         void *handle);
1051 
1052 /// Returns the memory descriptor for a memory object.
1053 ///
1054 /// @param memory Memory object.
1055 /// @param memory_desc Output memory descriptor (a copy).
1056 /// @returns #dnnl_success on success and a status describing the error
1057 ///     otherwise.
1058 dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
1059         const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc);
1060 
1061 /// Returns the engine of a memory object.
1062 ///
1063 /// @param memory Memory object.
1064 /// @param engine Output engine on which the memory is located.
1065 /// @returns #dnnl_success on success and a status describing the error
1066 ///     otherwise.
1067 dnnl_status_t DNNL_API dnnl_memory_get_engine(
1068         const_dnnl_memory_t memory, dnnl_engine_t *engine);
1069 
1070 /// Maps a memory object and returns a host-side pointer to a memory buffer
1071 /// with a copy of its contents.
1072 ///
1073 /// Mapping enables explicit direct access to memory contents for the engines
1074 /// that do not support it implicitly.
1075 ///
1076 /// Mapping is an exclusive operation - a memory object cannot be used in
1077 /// other operations until this memory object is unmapped.
1078 ///
1079 /// @note
1080 ///     Any primitives working with @p memory should be completed before
1081 ///     the memory is mapped. Use dnnl_stream_wait to synchronize the
1082 ///     corresponding execution stream.
1083 ///
1084 /// @note
1085 ///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
1086 ///     mainly provided for debug and testing purposes, and their performance
1087 ///     may be suboptimal.
1088 ///
1089 /// @param memory Memory object.
1090 /// @param mapped_ptr Output pointer to the mapped buffer.
1091 /// @returns #dnnl_success on success and a status describing the error
1092 ///     otherwise.
1093 dnnl_status_t DNNL_API dnnl_memory_map_data(
1094         const_dnnl_memory_t memory, void **mapped_ptr);
1095 
1096 /// Unmaps a memory object and writes back any changes made to the previously
1097 /// mapped memory buffer. The pointer to the mapped buffer must be obtained
1098 /// via the dnnl_memory_map_data() call.
1099 ///
1100 /// @note
1101 ///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
1102 ///     mainly provided for debug and testing purposes, and their performance
1103 ///     may be suboptimal.
1104 ///
1105 /// @param memory Memory object.
1106 /// @param mapped_ptr Pointer to the mapped buffer that must have been
1107 ///     obtained using the dnnl_memory_map_data() function.
1108 /// @returns #dnnl_success on success and a status describing the error
1109 ///     otherwise.
1110 dnnl_status_t DNNL_API dnnl_memory_unmap_data(
1111         const_dnnl_memory_t memory, void *mapped_ptr);
1112 
1113 /// Returns memory object's data handle.
1114 ///
1115 /// @param memory Memory object.
1116 /// @param handle Output data handle. For the CPU engine, the data handle is a
1117 ///     pointer to the actual data. For OpenCL it is a cl_mem.
1118 /// @returns #dnnl_success on success and a status describing the error
1119 ///     otherwise.
1120 dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
1121         const_dnnl_memory_t memory, void **handle);
1122 
1123 /// Sets the underlying memory buffer.
1124 ///
1125 /// See the description of dnnl_memory_set_data_handle_v2() for more details.
1126 ///
1127 /// @param memory Memory object.
1128 /// @param handle Data handle. For the CPU engine, the data handle is a
1129 ///     pointer to the actual data. For OpenCL it is a `cl_mem`.
1130 /// @returns #dnnl_success on success and a status describing the error
1131 ///     otherwise.
1132 dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
1133         dnnl_memory_t memory, void *handle);
1134 
1135 /// Sets the underlying memory buffer.
1136 ///
1137 /// @param memory Memory object.
1138 /// @param handle Data handle. For the CPU engine, the data handle is a
1139 ///     pointer to the actual data. For OpenCL it is a `cl_mem`.
1140 /// @param stream Stream to use to execute padding in.
1141 /// @returns #dnnl_success on success and a status describing the error
1142 ///     otherwise.
1143 dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
1144         dnnl_memory_t memory, void *handle, dnnl_stream_t stream);
1145 
1146 /// Destroys a memory object.
1147 ///
1148 /// @param memory Memory object to destroy.
1149 /// @returns #dnnl_success on success and a status describing the error
1150 ///     otherwise.
1151 dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
1152 
1153 /// @} dnnl_api_memory
1154 
1155 /// @addtogroup dnnl_api_primitives
1156 /// @{
1157 
1158 /// @addtogroup dnnl_api_reorder
1159 /// @{
1160 
1161 /// Creates a primitive descriptor for a reorder primitive.
1162 ///
1163 /// @param reorder_primitive_desc Output primitive descriptor.
1164 /// @param src_desc Source memory descriptor.
1165 /// @param src_engine Engine on which the source memory object will be
1166 ///     located.
1167 /// @param dst_desc Destination memory descriptor.
1168 /// @param dst_engine Engine on which the destination memory object
1169 ///     will be located.
1170 /// @param attr Primitive attributes to use (can be NULL).
1171 /// @returns #dnnl_success on success and a status describing the error
1172 ///     otherwise.
1173 dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
1174         dnnl_primitive_desc_t *reorder_primitive_desc,
1175         const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine,
1176         const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine,
1177         const_dnnl_primitive_attr_t attr);
1178 
1179 /// @} dnnl_api_reorder
1180 
1181 /// @addtogroup dnnl_api_concat
1182 /// @{
1183 
1184 /// Creates a primitive descriptor for an out-of-place concatenation
1185 /// primitive.
1186 ///
1187 /// @param concat_primitive_desc Output primitive descriptor.
1188 /// @param dst_desc Destination memory descriptor.
1189 /// @param n Number of source parameters.
1190 /// @param concat_dimension Source tensors will be concatenated over
1191 ///     dimension with this index. Note that order of dimensions does
1192 ///     not depend on memory format.
1193 /// @param src_descs Array of source memory descriptors with @p n elements.
1194 /// @param attr Primitive attributes to use (can be NULL).
1195 /// @param engine Engine to use.
1196 /// @returns #dnnl_success on success and a status describing the error
1197 ///     otherwise.
1198 dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
1199         dnnl_primitive_desc_t *concat_primitive_desc,
1200         const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension,
1201         const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
1202         dnnl_engine_t engine);
1203 
1204 /// @} dnnl_api_concat
1205 
1206 /// @addtogroup dnnl_api_sum
1207 /// @{
1208 
1209 /// Creates a primitive descriptor for an (out-of-place) sum primitive.
1210 ///
1211 /// @param sum_primitive_desc Output primitive descriptor.
1212 /// @param dst_desc Destination memory descriptor.
1213 /// @param n Number of source parameters.
1214 /// @param scales Vector of scales to multiply data in each source
1215 ///     memory by.
1216 /// @param src_descs Array of source memory descriptors having @p n elements.
1217 /// @param attr Primitive attributes to use (can be NULL).
1218 /// @param engine Engine to use.
1219 /// @returns #dnnl_success on success and a status describing the error
1220 ///     otherwise.
1221 dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
1222         dnnl_primitive_desc_t *sum_primitive_desc,
1223         const dnnl_memory_desc_t *dst_desc, int n, const float *scales,
1224         const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr,
1225         dnnl_engine_t engine);
1226 
1227 /// @} dnnl_api_sum
1228 
1229 /// @addtogroup dnnl_api_binary
1230 /// @{
1231 
1232 /// Initializes a descriptor for a binary primitive.
1233 ///
1234 /// @note
1235 ///     Memory descriptor @p dst_desc is allowed to be initialized with
1236 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1237 ///
1238 /// @note
1239 ///     Both memory descriptors must have the same number of dimensions.
///     Element broadcasting is supported for memory descriptor @p src1_desc
///     and is applied to @p src1_desc dimensions that have size equal to 1.
1242 ///
1243 /// @param binary_desc Output descriptor for a binary primitive.
1244 /// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
1245 ///     #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
1246 ///     #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
1247 ///     #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
1248 /// @param src0_desc Source 0 memory descriptor.
1249 /// @param src1_desc Source 1 memory descriptor.
1250 /// @param dst_desc Destination memory descriptor.
1251 /// @returns #dnnl_success on success and a status describing the error
1252 ///     otherwise.
1253 dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc,
1254         dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc,
1255         const dnnl_memory_desc_t *src1_desc,
1256         const dnnl_memory_desc_t *dst_desc);
1257 
1258 /// @} dnnl_api_binary
1259 
1260 /// @addtogroup dnnl_api_convolution
1261 /// @{
1262 
1263 /// Initializes a descriptor for a convolution forward propagation primitive.
1264 ///
1265 /// @note
1266 ///     Memory descriptors can be initialized with
1267 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
1268 ///
1269 /// Arrays @p strides, @p padding_l, and @p padding_r contain values for
1270 /// spatial dimensions only and hence must have the same number of elements as
1271 /// there are spatial dimensions. The order of values is the same as in the
1272 /// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
1273 ///
1274 /// @param conv_desc Output descriptor for a convolution primitive.
1275 /// @param prop_kind Propagation kind. Possible values are
1276 ///     #dnnl_forward_training and #dnnl_forward_inference.
1277 /// @param alg_kind Convolution algorithm. Possible values are
1278 ///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
1279 ///     #dnnl_convolution_auto.
1280 /// @param src_desc Source memory descriptor.
1281 /// @param weights_desc Weights memory descriptor.
1282 /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
1283 ///     descriptor, or a memory descriptor with format_kind set to
1284 ///     #dnnl_format_kind_undef disables the bias term.
1285 /// @param dst_desc Destination memory descriptor.
1286 /// @param strides Array of strides for spatial dimension.
1287 /// @param padding_l Array of padding values for low indices for each spatial
1288 ///     dimension `([[front,] top,] left)`.
1289 /// @param padding_r Array of padding values for high indices for each spatial
1290 ///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
1291 ///     padding is assumed to be symmetrical.
1292 /// @returns #dnnl_success on success and a status describing the error
1293 ///     otherwise.
1294 dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(
1295         dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
1296         dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
1297         const dnnl_memory_desc_t *weights_desc,
1298         const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
1299         const dnnl_dims_t strides, const dnnl_dims_t padding_l,
1300         const dnnl_dims_t padding_r);
1301 
/// Initializes a descriptor for a dilated convolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1344 
/// Initializes a descriptor for a convolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1377 
/// Initializes a descriptor for a dilated convolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1415 
/// Initializes a descriptor for a convolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1452 
/// Initializes a descriptor for a dilated convolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param conv_desc Output descriptor for a convolution primitive.
/// @param alg_kind Convolution algorithm. Possible values are
///     #dnnl_convolution_direct, #dnnl_convolution_winograd,
///     #dnnl_convolution_auto.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(
        dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1494 
1495 /// @} dnnl_api_convolution
1496 
1497 /// @addtogroup dnnl_api_deconvolution
1498 /// @{
1499 
/// Initializes a descriptor for a deconvolution forward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1537 
/// Initializes a descriptor for a dilated deconvolution forward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
        const dnnl_dims_t strides, const dnnl_dims_t dilates,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1579 
/// Initializes a descriptor for a deconvolution backward propagation primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1611 
/// Initializes a descriptor for a dilated deconvolution backward propagation
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1648 
/// Initializes a descriptor for a deconvolution weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p padding_l, and @p padding_r contain values for
/// spatial dimensions only and hence must have the same number of elements as
/// there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1684 
/// Initializes a descriptor for a dilated deconvolution weights gradient
/// primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
/// values for spatial dimensions only and hence must have the same number of
/// elements as there are spatial dimensions. The order of values is the same
/// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
/// and width.
///
/// @param deconv_desc Output descriptor for a deconvolution primitive.
/// @param alg_kind Deconvolution algorithm. Possible values are
///     #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param dilates Array of dilations for each spatial dimension. A zero value
///     means no dilation in the corresponding dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(
        dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1725 
1726 /// @} dnnl_api_deconvolution
1727 
1728 /// @addtogroup dnnl_api_shuffle
1729 /// @{
1730 
/// Initializes a descriptor for a shuffle forward propagation primitive.
///
/// @note
///     NOTE(review): @p group_size presumably must evenly divide the size of
///     the @p axis dimension of @p data_desc — confirm against implementation.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size);
1744 
/// Initializes a descriptor for a shuffle backward propagation primitive.
///
/// @note
///     NOTE(review): @p group_size presumably must evenly divide the size of
///     the @p axis dimension of @p diff_data_desc — confirm against
///     implementation.
///
/// @param shuffle_desc Output descriptor for a shuffle primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param axis The axis along which the data is shuffled.
/// @param group_size Shuffle group size.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(
        dnnl_shuffle_desc_t *shuffle_desc,
        const dnnl_memory_desc_t *diff_data_desc, int axis,
        dnnl_dim_t group_size);
1757 
1758 /// @} dnnl_api_shuffle
1759 
1760 /// @addtogroup dnnl_api_eltwise
1761 /// @{
1762 
/// Initializes a descriptor for an eltwise forward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Elementwise algorithm kind (for example,
///     #dnnl_eltwise_relu).
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc,
        float alpha, float beta);
1780 
/// Initializes a descriptor for an eltwise backward propagation primitive.
///
/// @param eltwise_desc Output descriptor for an eltwise primitive.
/// @param alg_kind Elementwise algorithm kind (for example,
///     #dnnl_eltwise_relu).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source and destination memory descriptor.
/// @param alpha The alpha parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @param beta The beta parameter for the elementwise operation. Specific
///     meaning depends on the algorithm.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(
        dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float alpha, float beta);
1797 
1798 /// @} dnnl_api_eltwise
1799 
1800 /// @addtogroup dnnl_api_softmax
1801 /// @{
1802 
/// Initializes a descriptor for a softmax forward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(
        dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);
1815 
/// Initializes a descriptor for a softmax backward propagation primitive.
///
/// @param softmax_desc Output descriptor for a softmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Destination memory descriptor.
/// @param softmax_axis Axis over which softmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(
        dnnl_softmax_desc_t *softmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int softmax_axis);
1828 
1829 /// @} dnnl_api_softmax
1830 
1831 /// @addtogroup dnnl_api_logsoftmax
1832 /// @{
1833 
/// Initializes a descriptor for a logsoftmax forward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);
1846 
/// Initializes a descriptor for a logsoftmax backward propagation primitive.
///
/// @param logsoftmax_desc Output descriptor for a logsoftmax primitive.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Destination memory descriptor.
/// @param logsoftmax_axis Axis over which logsoftmax is computed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(
        dnnl_logsoftmax_desc_t *logsoftmax_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, int logsoftmax_axis);
1859 
1860 /// @} dnnl_api_logsoftmax
1861 
1862 /// @addtogroup dnnl_api_pooling
1863 /// @{
1864 
/// Initializes a descriptor for a pooling forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1895 
/// Initializes a descriptor for pooling backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p padding_l, and @p padding_r contain values
/// for spatial dimensions only and hence must have the same number of elements
/// as there are spatial dimensions. The order of values is the same as in the
/// tensor: depth (for 3D tensors), height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(
        dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t padding_l,
        const dnnl_dims_t padding_r);
1924 
1925 /// @} dnnl_api_pooling
1926 
1927 /// @addtogroup dnnl_api_pooling_v2
1928 /// @{
1929 
/// Initializes a descriptor for pooling v2 (pooling with dilation support)
/// forward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1963 
/// Initializes a descriptor for pooling v2 (pooling with dilation support)
/// backward propagation primitive.
///
/// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
/// contain values for spatial dimensions only and hence must have the same
/// number of elements as there are spatial dimensions. The order of values
/// is the same as in the tensor: depth (for 3D tensors),
/// height (for 3D and 2D tensors), and width.
///
/// @param pool_desc Output descriptor for a pooling primitive.
/// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
///     #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg (same as
///     #dnnl_pooling_avg_exclude_padding).
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @param strides Array of strides for each spatial dimension.
/// @param kernel Array of kernel spatial dimensions.
/// @param dilation Array of dilations for each spatial dimension.
/// @param padding_l Array of padding values for low indices for each spatial
///     dimension `([[front,] top,] left)`.
/// @param padding_r Array of padding values for high indices for each spatial
///     dimension `([[back,] bottom,] right)`. Can be NULL in which case
///     padding is considered to be symmetrical.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(
        dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
        const dnnl_dims_t kernel, const dnnl_dims_t dilation,
        const dnnl_dims_t padding_l, const dnnl_dims_t padding_r);
1995 
1996 /// @} dnnl_api_pooling_v2
1997 
1998 /// @addtogroup dnnl_api_prelu
1999 /// @{
2000 
/// Initializes a descriptor for PReLU
/// (leaky ReLU with trainable alpha parameter)
/// forward propagation primitive.
///
/// @note
///     The weights descriptor is allowed to be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc);
2020 
/// Initializes a descriptor for PReLU
/// (leaky ReLU with trainable alpha parameter)
/// backward propagation primitive.
///
/// @note
///     The weights descriptor and diff_weights descriptor are allowed
///     to be initialized with #dnnl_format_tag_any or with format_kind
///     set to #dnnl_format_kind_any.
///
/// @param prelu_desc Output descriptor for a prelu primitive.
/// @param data_desc Source and destination memory descriptor.
/// @param weights_desc Alpha parameters memory descriptor.
/// @param diff_data_desc Diff source and destination memory descriptor.
/// @param diff_weights_desc Diff alpha parameters memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(
        dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *diff_weights_desc);
2042 
2043 /// @} dnnl_api_prelu
2044 
2045 /// @addtogroup dnnl_api_lrn
2046 /// @{
2047 
/// Initializes a descriptor for LRN forward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param data_desc Source and destination memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);
2066 
/// Initializes a descriptor for LRN backward propagation primitive.
///
/// @param lrn_desc Output descriptor for an LRN primitive.
/// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
///     #dnnl_lrn_within_channel.
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param local_size Regularization local size.
/// @param alpha The alpha regularization parameter.
/// @param beta The beta regularization parameter.
/// @param k The k regularization parameter.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc,
        dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
        float beta, float k);
2084 
2085 /// @} dnnl_api_lrn
2086 
2087 /// @addtogroup dnnl_api_batch_normalization
2088 /// @{
2089 
/// Initializes a descriptor for a batch normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param bnrm_desc Output descriptor for a batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);
2108 
/// Initializes a descriptor for a batch normalization backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param bnrm_desc Output descriptor for a batch normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param epsilon Batch normalization epsilon parameter.
/// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(
        dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags);
2130 
2131 /// @} dnnl_api_batch_normalization
2132 
2133 /// @addtogroup dnnl_api_layer_normalization
2134 /// @{
2135 
/// Initializes a descriptor for a layer normalization forward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the dst can refer to the same memory
///     as the src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param data_desc Source and destination memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p data_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(
        dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags);
2160 
/// Initializes a descriptor for a layer normalization backward propagation
/// primitive.
///
/// @note
///     In-place operation is supported: the diff_dst can refer to the same
///     memory as the diff_src.
///
/// @param lnrm_desc Output descriptor for a layer normalization primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
///     computed in this case).
/// @param diff_data_desc Diff source and diff destination memory descriptor.
/// @param data_desc Source memory descriptor.
/// @param stat_desc Memory descriptor for mean and variance. If this
///     parameter is NULL, a zero memory descriptor, or a memory descriptor
///     with format_kind set to #dnnl_format_kind_undef, then the memory
///     descriptor for stats is derived from @p data_desc by removing the last
///     dimension.
/// @param epsilon Layer normalization epsilon parameter.
/// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(
        dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *diff_data_desc,
        const dnnl_memory_desc_t *data_desc,
        const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags);
2188 
2189 /// @} dnnl_api_layer_normalization
2190 
2191 /// @addtogroup dnnl_api_inner_product
2192 /// @{
2193 
/// Initializes a descriptor for inner product forward propagation.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param src_desc Source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
///     descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param dst_desc Destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(
        dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_desc);
2217 
/// Initializes a descriptor for inner product backward propagation.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param diff_src_desc Diff source memory descriptor.
/// @param weights_desc Weights memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(
        dnnl_inner_product_desc_t *ip_desc,
        const dnnl_memory_desc_t *diff_src_desc,
        const dnnl_memory_desc_t *weights_desc,
        const dnnl_memory_desc_t *diff_dst_desc);
2235 
/// Initializes a descriptor for inner product weights gradient primitive.
///
/// @note
///     Memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param ip_desc Output descriptor for an inner product primitive.
/// @param src_desc Source memory descriptor.
/// @param diff_weights_desc Diff weights memory descriptor.
/// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
///     memory descriptor, or a memory descriptor with format_kind set to
///     #dnnl_format_kind_undef disables the bias term.
/// @param diff_dst_desc Diff destination memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(
        dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc,
        const dnnl_memory_desc_t *diff_weights_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_desc);
2256 
2257 /// @} dnnl_api_inner_product
2258 
2259 /// @addtogroup dnnl_api_attributes
2260 /// @{
2261 
/// Set quantization scale and shift parameters for RNN data tensors.
///
/// For performance reasons, the low-precision configuration of the RNN
/// primitives expects input activations to have the unsigned 8-bit integer
/// data type. The scale and shift parameters are used to quantize
/// floating-point data to unsigned integer and must be passed to the RNN
/// primitive using attributes.
///
/// The quantization formula is `scale * data + shift`.
///
/// @note
///     Quantization scale and shift are common for src_layer, src_iter,
///     dst_iter, and dst_layer.
///
/// Example usage:
/// @code
///     // RNN parameters
///     int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
///     // Activations quantization parameters
///     float scale = 63.f, shift = 64.f;
///
///     dnnl_primitive_attr_t rnn_attr;
///     // Create default attributes
///     dnnl_primitive_attr_create(&rnn_attr);
///
///     // Set scale and shift for int8 quantization of activation
///     dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
///
///     // Create and configure rnn op_desc
///     dnnl_rnn_desc_t rnn_d;
///     dnnl_primitive_desc_t rnn_pd;
///     dnnl_primitive_desc_create(&rnn_pd, &rnn_d, rnn_attr, engine, NULL);
/// @endcode
///
/// @param attr Primitive attributes.
/// @param scale The value to scale the data by.
/// @param shift The value to shift the data by.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
        dnnl_primitive_attr_t attr, const float scale, const float shift);
2303 
/// Returns the quantization scale and shift parameters for RNN data tensors.
///
/// @note
///     Quantization scale and shift are common for src_layer, src_iter,
///     dst_iter, and dst_layer.
///
/// @param attr Primitive attributes.
/// @param scale Output scale value applied to the data.
/// @param shift Output shift value applied to the data.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
        const_dnnl_primitive_attr_t attr, float *scale, float *shift);
2317 
/// Sets quantization scaling factors for RNN weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
///     The dimension order is always native and does not depend on the actual
///     layout used. For example, five-dimensional weights always have (l, d,
///     i, g, o) logical dimension ordering.
///
/// @note
///     Quantization scales are common for weights_layer and
///     weights_iteration.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that must contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
///     Violations can only be detected when the attributes are used to create
///     a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
2350 
/// Returns the quantization scaling factors for RNN weights tensors.
///
/// @param attr Primitive attributes.
/// @param count Output number of elements in the @p scales array.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. A mask
///     of 0 means a common scaling factor is used for the whole output
///     tensor.
/// @param scales Output pointer to an array of scaling factors that
///     contains @p count values and for which the following equality holds:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);
2369 
/// Sets quantization scaling factors for RNN projection weights tensors. The
/// low-precision configuration of the RNN primitives expects input weights to
/// use the signed 8-bit integer data type. The scaling factors are used to
/// quantize floating-point data to signed integer and must be passed to RNN
/// primitives using attributes.
///
/// @note
///     The dimension order is always native and does not depend on the actual
///     layout used. For example, five-dimensional weights always have (l, d,
///     i, g, o) logical dimension ordering.
///
/// @param attr Primitive attributes.
/// @param count Number of elements in the @p scales array.
/// @param mask Scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. Set the
///     mask to 0 to use a common scaling factor for the whole output
///     tensor.
/// @param scales Array of output scaling factors that must contain @p count
///     values and the following equality must hold:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
///     Violations can only be detected when the attributes are used to create
///     a primitive descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
        dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
        const float *scales);
2399 
/// Returns the quantization scaling factors for RNN projection weights tensors.
///
/// @param attr Primitive attributes.
/// @param count Output number of elements in the @p scales array.
/// @param mask Output scaling factors correspondence mask that defines the
///     correspondence between the output tensor dimensions and the @p
///     scales vector. The set i-th bit indicates that a dedicated scaling
///     factor should be used for each index along that dimension. A mask
///     of 0 means a common scaling factor is used for the whole output
///     tensor.
/// @param scales Output pointer to an array of scaling factors that
///     contains @p count values and for which the following equality holds:
///     \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
        const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
        const float **scales);
2418 
2419 /// @} dnnl_api_attributes
2420 
2421 /// @addtogroup dnnl_api_rnn
2422 /// @{
2423 
/// Initializes a descriptor for vanilla RNN forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the RNN forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with
///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for a vanilla RNN primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
///     #dnnl_eltwise_tanh, or #dnnl_eltwise_logistic.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
/// @param beta Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha,
        float beta);
2473 
2474 /// Initializes a descriptor for vanilla RNN backward propagation primitive.
2475 ///
2476 /// The following arguments may either be @c NULL or point to a zero memory
2477 /// descriptor:
2478 /// - @p src_iter_desc together with @p diff_src_iter_desc,
2479 /// - @p bias_desc together with @p diff_bias_desc,
2480 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
2481 ///
2482 /// This would then indicate that the RNN backward propagation primitive should
2483 /// not use the respective data and should use zero values instead.
2484 ///
2485 /// @note
2486 ///     All memory descriptors can be initialized with
2487 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
2488 ///
2489 /// @param rnn_desc Output descriptor for vanilla RNN primitive.
2490 /// @param prop_kind Propagation kind. Must be #dnnl_backward.
2491 /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
2492 ///     #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
2493 /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
2494 ///     info.
2495 /// @param src_layer_desc Memory descriptor for the input vector.
2496 /// @param src_iter_desc Memory descriptor for the input recurrent hidden
2497 ///     state vector.
2498 /// @param weights_layer_desc Memory descriptor for the weights applied to the
2499 ///     layer input.
2500 /// @param weights_iter_desc Memory descriptor for the weights applied to the
2501 ///     recurrent input.
2502 /// @param bias_desc Bias memory descriptor.
2503 /// @param dst_layer_desc Memory descriptor for the output vector.
2504 /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
2505 ///     state vector.
2506 /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
2507 /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
2508 ///     hidden state vector.
2509 /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
2510 ///     applied to the layer input.
2511 /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
2512 ///     applied to the recurrent input.
2513 /// @param diff_bias_desc Diff bias memory descriptor.
2514 /// @param diff_dst_layer_desc Memory descriptor for the diff of output
2515 ///     vector.
2516 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
2517 ///     recurrent hidden state vector.
2518 /// @param flags Unused.
2519 /// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
2520 /// @param beta Unused.
2521 /// @returns #dnnl_success on success and a status describing the error
2522 ///     otherwise.
2523 dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(
2524         dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
2525         const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction,
2526         const dnnl_memory_desc_t *src_layer_desc,
2527         const dnnl_memory_desc_t *src_iter_desc,
2528         const dnnl_memory_desc_t *weights_layer_desc,
2529         const dnnl_memory_desc_t *weights_iter_desc,
2530         const dnnl_memory_desc_t *bias_desc,
2531         const dnnl_memory_desc_t *dst_layer_desc,
2532         const dnnl_memory_desc_t *dst_iter_desc,
2533         const dnnl_memory_desc_t *diff_src_layer_desc,
2534         const dnnl_memory_desc_t *diff_src_iter_desc,
2535         const dnnl_memory_desc_t *diff_weights_layer_desc,
2536         const dnnl_memory_desc_t *diff_weights_iter_desc,
2537         const dnnl_memory_desc_t *diff_bias_desc,
2538         const dnnl_memory_desc_t *diff_dst_layer_desc,
2539         const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
2540         float alpha, float beta);
2541 
/// Initializes a descriptor for LSTM forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_forward_desc_init_v2 to initialize forward LSTM with and
///     without peephole.
/// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and
///     without peephole / recurrent projection layer.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
2596 
/// Initializes a descriptor for an LSTM (with or without peephole) forward
/// propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p weights_peephole_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_forward_desc_init_v3 to initialize forward LSTM with and
///     without peephole / recurrent projection layer.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
2654 
/// Initializes a descriptor for an LSTM (with or without peephole and with
/// or without recurrent projection layer) forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc,
/// - @p weights_peephole_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
///
/// This would then indicate that the LSTM forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// The @p weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM doesn't have a
/// recurrent projection layer.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
///     the hidden states to get the recurrent projection (according to the
///     Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *weights_projection_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags);
2717 
/// Initializes a descriptor for an LSTM backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init_v2 to initialize backward LSTM with and
///     without peephole.
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
///     without peephole / recurrent projection layer.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
2798 
/// Initializes a descriptor for an LSTM (with or without peephole) backward
/// propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @sa dnnl_lstm_backward_desc_init_v3 to initialize backward LSTM with and
///     without peephole / recurrent projection layer.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
///     applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
2886 
/// Initializes a descriptor for an LSTM (with or without peephole and with or
/// without recurrent projection layer) backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
///   and @p diff_src_iter_c_desc,
/// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
///   and @p diff_dst_iter_c_desc.
///
/// This would then indicate that the LSTM backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// The @p weights_projection_desc together with
/// @p diff_weights_projection_desc could either be @c NULL or point to a zero
/// memory descriptor. This would then indicate that the LSTM doesn't have a
/// recurrent projection layer.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for LSTM primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param src_iter_c_desc Memory descriptor for the input recurrent cell
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param weights_peephole_desc Memory descriptor for the weights applied to
///     the cell states (according to the Peephole LSTM formula).
/// @param weights_projection_desc Memory descriptor for the weights applied to
///     the hidden states to get the recurrent projection (according to the
///     Projection LSTM formula).
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_src_iter_c_desc Memory descriptor for the diff of input
///     recurrent cell state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_weights_peephole_desc Memory descriptor for the diff of weights
///     applied to the cell states (according to the Peephole LSTM formula).
/// @param diff_weights_projection_desc Memory descriptor for the diff of
///     weights applied to the hidden states to get the recurrent projection
///     (according to the Projection LSTM formula).
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param diff_dst_iter_c_desc Memory descriptor for the diff of output
///     recurrent cell state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(
        dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
        dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *src_iter_c_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *weights_peephole_desc,
        const dnnl_memory_desc_t *weights_projection_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *dst_iter_c_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_src_iter_c_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_weights_peephole_desc,
        const dnnl_memory_desc_t *diff_weights_projection_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc,
        const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags);
2984 
/// Initializes a descriptor for GRU forward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc,
/// - @p bias_desc,
/// - @p dst_iter_desc.
///
/// This would then indicate that the GRU forward propagation primitive should
/// not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Possible values are
///     #dnnl_forward_training and #dnnl_forward_inference.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3028 
/// Initializes a descriptor for GRU backward propagation primitive.
///
/// The following arguments may either be @c NULL or point to a zero memory
/// descriptor:
/// - @p src_iter_desc together with @p diff_src_iter_desc,
/// - @p bias_desc together with @p diff_bias_desc,
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
///
/// This would then indicate that the GRU backward propagation primitive
/// should not use them and should default to zero values instead.
///
/// @note
///     All memory descriptors can be initialized with #dnnl_format_tag_any or
///     with format_kind set to #dnnl_format_kind_any.
///
/// @param rnn_desc Output descriptor for GRU primitive.
/// @param prop_kind Propagation kind. Must be #dnnl_backward.
/// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
///     info.
/// @param src_layer_desc Memory descriptor for the input vector.
/// @param src_iter_desc Memory descriptor for the input recurrent hidden
///     state vector.
/// @param weights_layer_desc Memory descriptor for the weights applied to the
///     layer input.
/// @param weights_iter_desc Memory descriptor for the weights applied to the
///     recurrent input.
/// @param bias_desc Bias memory descriptor.
/// @param dst_layer_desc Memory descriptor for the output vector.
/// @param dst_iter_desc Memory descriptor for the output recurrent hidden
///     state vector.
/// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
/// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
///     hidden state vector.
/// @param diff_weights_layer_desc Memory descriptor for the diff of weights
///     applied to the layer input.
/// @param diff_weights_iter_desc Memory descriptor for the diff of weights
///     applied to the recurrent input.
/// @param diff_bias_desc Diff bias memory descriptor.
/// @param diff_dst_layer_desc Memory descriptor for the diff of output
///     vector.
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
///     recurrent hidden state vector.
/// @param flags Unused.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc,
        dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
        const dnnl_memory_desc_t *src_layer_desc,
        const dnnl_memory_desc_t *src_iter_desc,
        const dnnl_memory_desc_t *weights_layer_desc,
        const dnnl_memory_desc_t *weights_iter_desc,
        const dnnl_memory_desc_t *bias_desc,
        const dnnl_memory_desc_t *dst_layer_desc,
        const dnnl_memory_desc_t *dst_iter_desc,
        const dnnl_memory_desc_t *diff_src_layer_desc,
        const dnnl_memory_desc_t *diff_src_iter_desc,
        const dnnl_memory_desc_t *diff_weights_layer_desc,
        const dnnl_memory_desc_t *diff_weights_iter_desc,
        const dnnl_memory_desc_t *diff_bias_desc,
        const dnnl_memory_desc_t *diff_dst_layer_desc,
        const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3090 
3091 /// Initializes a descriptor for LBR GRU forward propagation primitive.
3092 ///
3093 /// The following arguments may either be @c NULL or point to a zero memory
3094 /// descriptor:
3095 /// - @p src_iter_desc,
3096 /// - @p bias_desc,
3097 /// - @p dst_iter_desc.
3098 ///
3099 /// This would then indicate that the LBR GRU forward propagation primitive
3100 /// should not use them and should default to zero values instead.
3101 ///
3102 /// @param rnn_desc Output descriptor for LBR GRU primitive.
3103 /// @param prop_kind Propagation kind. Possible values are
3104 ///     #dnnl_forward_training and #dnnl_forward_inference.
3105 /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
3106 ///     info.
3107 /// @param src_layer_desc Memory descriptor for the input vector.
3108 /// @param src_iter_desc Memory descriptor for the input recurrent hidden
3109 ///     state vector.
3110 /// @param weights_layer_desc Memory descriptor for the weights applied to the
3111 ///     layer input.
3112 /// @param weights_iter_desc Memory descriptor for the weights applied to the
3113 ///     recurrent input.
3114 /// @param bias_desc Bias memory descriptor.
3115 /// @param dst_layer_desc Memory descriptor for the output vector.
3116 /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
3117 ///     state vector.
3118 /// @param flags Unused.
3119 /// @returns #dnnl_success on success and a status describing the error
3120 ///     otherwise.
3121 dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc,
3122         dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
3123         const dnnl_memory_desc_t *src_layer_desc,
3124         const dnnl_memory_desc_t *src_iter_desc,
3125         const dnnl_memory_desc_t *weights_layer_desc,
3126         const dnnl_memory_desc_t *weights_iter_desc,
3127         const dnnl_memory_desc_t *bias_desc,
3128         const dnnl_memory_desc_t *dst_layer_desc,
3129         const dnnl_memory_desc_t *dst_iter_desc, unsigned flags);
3130 
3131 /// Initializes a descriptor for LBR GRU backward propagation primitive.
3132 ///
3133 /// The following arguments may either be @c NULL or point to a zero memory
3134 /// descriptor:
3135 /// - @p src_iter_desc together with @p diff_src_iter_desc,
3136 /// - @p bias_desc together with @p diff_bias_desc,
3137 /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
3138 ///
3139 /// This would then indicate that the LBR GRU backward propagation primitive
3140 /// should not use them and should default to zero values instead.
3141 ///
3142 /// @note
3143 ///     All memory descriptors can be initialized with
3144 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
3145 ///
3146 /// @param rnn_desc Output descriptor for LBR GRU primitive.
3147 /// @param prop_kind Propagation kind. Must be #dnnl_backward.
3148 /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
3149 ///     info.
3150 /// @param src_layer_desc Memory descriptor for the input vector.
3151 /// @param src_iter_desc Memory descriptor for the input recurrent hidden
3152 ///     state vector.
3153 /// @param weights_layer_desc Memory descriptor for the weights applied to the
3154 ///     layer input.
3155 /// @param weights_iter_desc Memory descriptor for the weights applied to the
3156 ///     recurrent input.
3157 /// @param bias_desc Bias memory descriptor.
3158 /// @param dst_layer_desc Memory descriptor for the output vector.
3159 /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
3160 ///     state vector.
3161 /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
3162 /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
3163 ///     hidden state vector.
3164 /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
3165 ///     applied to the layer input.
3166 /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
3167 ///     applied to the recurrent input.
3168 /// @param diff_bias_desc Diff bias memory descriptor.
3169 /// @param diff_dst_layer_desc Memory descriptor for the diff of output
3170 ///     vector.
3171 /// @param diff_dst_iter_desc Memory descriptor for the diff of output
3172 ///     recurrent hidden state vector.
3173 /// @param flags Unused.
3174 /// @returns #dnnl_success on success and a status describing the error
3175 ///     otherwise.
3176 dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(
3177         dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind,
3178         dnnl_rnn_direction_t direction,
3179         const dnnl_memory_desc_t *src_layer_desc,
3180         const dnnl_memory_desc_t *src_iter_desc,
3181         const dnnl_memory_desc_t *weights_layer_desc,
3182         const dnnl_memory_desc_t *weights_iter_desc,
3183         const dnnl_memory_desc_t *bias_desc,
3184         const dnnl_memory_desc_t *dst_layer_desc,
3185         const dnnl_memory_desc_t *dst_iter_desc,
3186         const dnnl_memory_desc_t *diff_src_layer_desc,
3187         const dnnl_memory_desc_t *diff_src_iter_desc,
3188         const dnnl_memory_desc_t *diff_weights_layer_desc,
3189         const dnnl_memory_desc_t *diff_weights_iter_desc,
3190         const dnnl_memory_desc_t *diff_bias_desc,
3191         const dnnl_memory_desc_t *diff_dst_layer_desc,
3192         const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags);
3193 
3194 /// @} dnnl_api_rnn
3195 
3196 /// @addtogroup dnnl_api_matmul
3197 /// @{
3198 
3199 /// Initializes a matrix multiplication descriptor.
3200 ///
3201 /// @param matmul_desc Output descriptor for matmul primitive.
3202 /// @param src_desc Source memory descriptor (matrix A)
3203 /// @param weights_desc Weights memory descriptor (matrix B)
3204 /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
3205 ///     descriptor, or a memory descriptor with format_kind set to
3206 ///     #dnnl_format_kind_undef disables the bias term.
3207 /// @param dst_desc Destination memory descriptor (matrix C).
3208 /// @returns #dnnl_success on success and a status describing the error
3209 ///     otherwise.
3210 dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc,
3211         const dnnl_memory_desc_t *src_desc,
3212         const dnnl_memory_desc_t *weights_desc,
3213         const dnnl_memory_desc_t *bias_desc,
3214         const dnnl_memory_desc_t *dst_desc);
3215 
3216 /// @} dnnl_api_matmul
3217 
3218 /// @addtogroup dnnl_api_resampling Resampling
3219 /// @{
3220 
3221 /// Initializes a descriptor for a resampling forward propagation primitive.
3222 ///
3223 /// @note
3224 ///     Destination memory descriptor is allowed to be initialized with
3225 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
3226 ///
3227 ///
3228 /// @param resampling_desc Output descriptor for a resampling primitive.
3229 /// @param prop_kind Propagation kind. Possible values are
3230 ///     #dnnl_forward_training and #dnnl_forward_inference.
3231 /// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest,
3232 ///     or #dnnl_resampling_linear.
3233 /// @param factors Array of scaling factors for spatial dimension.
3234 /// @param src_desc Source memory descriptor.
3235 /// @param dst_desc Destination memory descriptor.
3236 /// @returns #dnnl_success on success and a status describing the error
3237 ///     otherwise.
3238 dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(
3239         dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind,
3240         dnnl_alg_kind_t alg_kind, const float *factors,
3241         const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc);
3242 
/// Initializes a descriptor for a resampling backward propagation primitive.
///
/// @param resampling_desc Output descriptor for a resampling primitive.
/// @param alg_kind resampling algorithm kind: either
///     #dnnl_resampling_nearest, or #dnnl_resampling_linear.
/// @param factors Array of scaling factors for spatial dimension.
/// @param diff_src_desc Diff source memory descriptor.
/// @param diff_dst_desc Diff destination memory descriptor.
3251 /// @returns #dnnl_success on success and a status describing the error
3252 ///     otherwise.
3253 ///
3254 dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(
3255         dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind,
3256         const float *factors, const dnnl_memory_desc_t *diff_src_desc,
3257         const dnnl_memory_desc_t *diff_dst_desc);
3258 
3259 /// @} dnnl_api_resampling
3260 
3261 /// @addtogroup dnnl_api_reduction Reduction
3262 /// @{
3263 
3264 /// Initializes a descriptor for a reduction primitive.
3265 ///
3266 /// @note
3267 ///     Destination memory descriptor is allowed to be initialized with
3268 ///     #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
3269 ///
3270 ///
3271 /// @param desc Output descriptor for a reduction primitive.
3272 /// @param alg_kind reduction algorithm kind. Possible values:
3273 ///     #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
3274 ///     #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
3275 ///     #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
3276 ///     #dnnl_reduction_norm_lp_power_p_sum.
/// @param src_desc Source memory descriptor.
/// @param dst_desc Destination memory descriptor.
/// @param p Algorithm specific parameter.
/// @param eps Algorithm specific parameter.
3281 /// @returns #dnnl_success on success and a status describing the error
3282 ///     otherwise.
3283 ///
3284 dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc,
3285         dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
3286         const dnnl_memory_desc_t *dst_desc, float p, float eps);
3287 
3288 /// @} dnnl_api_reduction
3289 
3290 /// @} dnnl_api_primitives
3291 
3292 /// @addtogroup dnnl_api_engine
3293 /// @{
3294 
3295 /// Returns the number of engines of a particular kind.
3296 ///
3297 /// @param kind Kind of engines to count.
3298 /// @returns Count of the engines.
3299 size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind);
3300 
3301 /// Creates an engine.
3302 ///
3303 /// @param engine Output engine.
3304 /// @param kind Engine kind.
3305 /// @param index Engine index that should be between 0 and the count of
3306 ///     engines of the requested kind.
3307 /// @returns #dnnl_success on success and a status describing the error
3308 ///     otherwise.
3309 dnnl_status_t DNNL_API dnnl_engine_create(
3310         dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index);
3311 
3312 /// Returns the kind of an engine.
3313 ///
3314 /// @param engine Engine to query.
3315 /// @param kind Output engine kind.
3316 /// @returns #dnnl_success on success and a status describing the error
3317 ///     otherwise.
3318 dnnl_status_t DNNL_API dnnl_engine_get_kind(
3319         dnnl_engine_t engine, dnnl_engine_kind_t *kind);
3320 
3321 /// Destroys an engine.
3322 ///
3323 /// @param engine Engine to destroy.
3324 /// @returns #dnnl_success on success and a status describing the error
3325 ///     otherwise.
3326 dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine);
3327 
3328 /// @} dnnl_api_engine
3329 
3330 /// @addtogroup dnnl_api_stream
3331 /// @{
3332 
3333 /// Creates an execution stream.
3334 ///
3335 /// @param stream Output execution stream.
3336 /// @param engine Engine to create the execution stream on.
3337 /// @param flags Stream behavior flags (@sa dnnl_stream_flags_t).
3338 /// @returns #dnnl_success on success and a status describing the error
3339 ///     otherwise.
3340 dnnl_status_t DNNL_API dnnl_stream_create(
3341         dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags);
3342 
3343 /// Returns the engine of a stream object.
3344 ///
3345 /// @param stream Stream object.
3346 /// @param engine Output engine on which the stream is created.
3347 /// @returns #dnnl_success on success and a status describing the error
3348 ///     otherwise.
3349 dnnl_status_t DNNL_API dnnl_stream_get_engine(
3350         const_dnnl_stream_t stream, dnnl_engine_t *engine);
3351 
3352 /// Waits for all primitives in the execution stream to finish computations.
3353 ///
3354 /// @param stream Execution stream.
3355 /// @returns #dnnl_success on success and a status describing the error
3356 ///     otherwise.
3357 dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream);
3358 
3359 /// Destroys an execution stream.
3360 ///
3361 /// @param stream Execution stream to destroy.
3362 /// @returns #dnnl_success on success and a status describing the error
3363 ///     otherwise.
3364 dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream);
3365 
3366 /// @} dnnl_api_stream
3367 
3368 /// @addtogroup dnnl_api_primitive_cache
3369 /// @{
3370 
3371 /// Returns the number of primitives that can be held in the primitive cache
3372 /// at the same time.
3373 ///
3374 /// @param capacity Primitive cache capacity to query. Concurrently
3375 /// accessing @p capacity is safe.
3376 /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3377 ///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
3378 ///     success.
3379 dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
3380 
3381 /// Sets a number of primitives that can be held in the primitive cache
3382 /// at a time.
3383 ///
3384 /// @param capacity Primitive cache capacity to set. If a new @p capacity is
3385 /// less than a number of primitives that the primitive cache already has
3386 /// then the excess entries will be evicted. Setting the @p capacity to 0
3387 /// clears the primitive cache and disables it. Concurrently modifying
3388 /// @p capacity is safe.
3389 /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3390 ///     @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
3391 ///     success.
3392 dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
3393 
3394 /// @} dnnl_api_primitive_cache
3395 
3396 /// @addtogroup dnnl_api_service
3397 /// @{
3398 
3399 /// Configures verbose output to stdout.
3400 ///
3401 /// @note
3402 ///     Enabling verbose output affects performance.
3403 ///     This setting overrides the DNNL_VERBOSE environment variable.
3404 ///
3405 /// @param level Verbosity level:
3406 ///  - 0: no verbose output (default),
3407 ///  - 1: primitive information at execution,
3408 ///  - 2: primitive information at creation and execution.
3409 /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3410 ///     @p level value is invalid, and #dnnl_success/#dnnl::status::success on
3411 ///     success.
3412 dnnl_status_t DNNL_API dnnl_set_verbose(int level);
3413 
3414 /// Configures dumping of JIT-generated code.
3415 ///
3416 /// @note
3417 ///     This setting overrides the DNNL_JIT_DUMP environment variable.
3418 ///
3419 /// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
/// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
///     @p enable value is invalid, and #dnnl_success/#dnnl::status::success
///     on success.
3423 dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
3424 
3425 /// Returns library version information.
3426 /// @returns Pointer to a constant structure containing
3427 ///  - major: major version number,
3428 ///  - minor: minor version number,
3429 ///  - patch: patch release number,
3430 ///  - hash: git commit hash.
3431 const dnnl_version_t DNNL_API *dnnl_version(void);
3432 
3433 /// Sets library profiling flags. The flags define which profilers are
3434 /// supported.
3435 ///
3436 /// @note
3437 ///     This setting overrides DNNL_JIT_PROFILE environment variable.
3438 ///
3439 /// @sa @ref dev_guide_profilers
3440 ///
3441 /// @param flags Profiling flags that can contain the following bits:
3442 ///     - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Amplifier
3443 ///         (on by default)
3444 ///     - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
3445 ///         jit-pid.dump output (off by default). The location of the output
3446 ///         is controlled via JITDUMPDIR environment variable or via
3447 ///         dnnl_set_jit_profiling_jitdumpdir() function.
3448 ///     - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
3449 ///         perf-pid.map output (off by default). The output is always placed
3450 ///         into /tmp.
3451 ///
3452 ///     Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
3453 ///
3454 /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
3455 ///     @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
3456 ///     success.
3457 dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
3458 
3459 /// Sets JIT dump output path. Only applicable to Linux and is only
3460 /// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set.
3461 ///
3462 /// After the first JIT kernel is generated, the jitdump output will be placed
3463 /// into temporary directory created using the mkdtemp template
3464 /// 'dir/.debug/jit/dnnl.XXXXXX'.
3465 ///
3466 /// @sa @ref dev_guide_profilers
3467 ///
3468 /// @note
3469 ///     This setting overrides JITDUMPDIR environment variable.  If
3470 ///     JITDUMPDIR is not set, and this function is never called, the path
3471 ///     defaults to HOME. Passing NULL reverts the value to default.
3472 ///
3473 /// @note
3474 ///     The directory is accessed only when the first JIT kernel is being
3475 ///     created. JIT profiling will be disabled in case of any errors
3476 ///     accessing or creating this directory.
3477 ///
3478 /// @param dir JIT dump output path.
3479 /// @returns #dnnl_success/#dnnl::status::success if the
3480 ///     output directory was set correctly and an error status otherwise.
3481 /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
3482 dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);
3483 
3484 /// Sets the maximal ISA the library can dispatch to on the CPU. See
3485 /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
3486 /// the C and C++ API functions respectively.
3487 ///
3488 /// This function has effect only once, and returns an error on subsequent
3489 /// calls. It should also be invoked before any other oneDNN API call, otherwise
3490 /// it may return an error.
3491 ///
3492 /// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
3493 /// environment variable can be set to the desired maximal ISA name in upper
3494 /// case and with dnnl_cpu_isa prefix removed. For example:
3495 /// `DNNL_MAX_CPU_ISA=AVX2`.
3496 ///
3497 /// @note
3498 ///     The ISAs are only partially ordered:
3499 ///         - SSE41 < AVX < AVX2,
3500 ///         - AVX2 < AVX512_MIC < AVX512_MIC_4OPS,
3501 ///         - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
3502 ///           < AVX512_CORE_AMX,
3503 ///         - AVX2 < AVX2_VNNI.
3504 ///
3505 /// @sa @ref dev_guide_cpu_dispatcher_control for more details
3506 ///
3507 /// @param isa Maximal ISA the library should dispatch to. Pass
3508 ///     #dnnl_cpu_isa_all/#dnnl::cpu_isa::all to remove ISA restrictions
3509 ///     (except for ISAs with initial support in the library).
3510 /// @returns #dnnl_success/#dnnl::status::success on success and a
3511 ///     #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
3512 ///     parameter is invalid or the ISA cannot be changed at this time.
3513 /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
3514 ///     was disabled at build time (see @ref dev_guide_build_options for more
3515 ///     details).
3516 dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
3517 
3518 /// Gets the maximal ISA the library can dispatch to on the CPU. See
3519 /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
3520 /// the C and C++ API functions respectively.
3521 ///
3522 /// @sa @ref dev_guide_cpu_dispatcher_control for more details
3523 ///
3524 /// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
3525 ///     dispatch to.
3526 dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
3527 
3528 /// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
3529 /// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
3530 /// API functions respectively.
3531 ///
3532 /// This function has effect only once, and returns an error on subsequent
3533 /// calls. It should also be invoked before any other oneDNN API call, otherwise
3534 /// it may return an error.
3535 ///
3536 /// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
3537 /// @sa @ref dev_guide_cpu_isa_hints for more details
3538 ///
3539 /// @param isa_hints CPU ISA hints to be passed over to the implementation.
3540 ///     Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
3541 ///     default features i.e. no hints.
3542 /// @returns #dnnl_success/#dnnl::status::success on success and a
3543 ///     #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
3544 ///     be specified at the current time.
3545 /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
3546 ///     was disabled at build time (see @ref dev_guide_build_options for more
3547 ///     details).
3548 dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
3549 
/// Gets the ISA-specific hints that the library can follow. See
/// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
/// returned by the C and C++ API functions respectively.
3553 ///
3554 /// @sa @ref dev_guide_cpu_isa_hints for more details
3555 ///
/// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA-specific hints the
///     library can follow.
3558 dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
3559 
3560 /// @} dnnl_api_service
3561 
3562 /// @addtogroup dnnl_api_blas
3563 /// @{
3564 
3565 /// Performs single-precision matrix-matrix multiply.
3566 ///
3567 /// The operation is defined as:
3568 ///
3569 /// `C := alpha * op( A ) * op( B ) + beta * C`
3570 ///
3571 /// where
3572 ///  - `op( X ) = X` or `op( X ) = X**T`,
3573 ///  - `alpha` and `beta` are scalars, and
3574 ///  - `A`, `B`, and `C` are matrices:
3575 ///     - `op( A )` is an `MxK` matrix,
3576 ///     - `op( B )` is an `KxN` matrix,
3577 ///     - `C` is an `MxN` matrix.
3578 ///
3579 /// The matrices are assumed to be stored in row-major order (the elements in
3580 /// each of the matrix rows are contiguous in memory).
3581 ///
3582 /// @note
3583 ///     This API does not support XERBLA. Instead, unlike the standard BLAS
3584 ///     functions, this one returns a dnnl_status_t value to allow error
3585 ///     handling.
3586 ///
3587 /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3588 ///     transposed, and 'T' or 't' means that A is transposed.
3589 /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3590 ///     transposed, and 'T' or 't' means that B is transposed.
3591 /// @param M The M dimension.
3592 /// @param N The N dimension.
3593 /// @param K The K dimension.
3594 /// @param alpha The alpha parameter that is used to scale the product of
3595 ///     matrices A and B.
3596 /// @param A A pointer to the A matrix data.
3597 /// @param lda The leading dimension for the matrix A.
3598 /// @param B A pointer to the B matrix data.
3599 /// @param ldb The leading dimension for the matrix B.
3600 /// @param beta The beta parameter that is used to scale the matrix C.
3601 /// @param C A pointer to the C matrix data.
3602 /// @param ldc The leading dimension for the matrix C.
3603 /// @returns #dnnl_success/#dnnl::status::success on success and a status
3604 ///     describing the error otherwise.
3605 dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
3606         dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
3607         const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);
3608 
3609 /// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
3610 /// signed matrix B, and 32-bit signed resulting matrix C.
3611 ///
3612 /// The operation is defined as:
3613 ///
3614 /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
3615 ///
3616 /// where
3617 ///  - `op( X ) = X` or `op( X ) = X**T`,
3618 ///  - `alpha` and `beta` are scalars, and
3619 ///  - `A`, `B`, and `C` are matrices:
3620 ///     - `op( A )` is an `MxK` matrix,
3621 ///     - `op( B )` is an `KxN` matrix,
3622 ///     - `C` is an `MxN` matrix.
3623 ///  - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
3624 ///  - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
3625 ///  - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
3626 ///    - if `offsetc = F`: the `len` must be at least `1`,
3627 ///    - if `offsetc = C`: the `len` must be at least `max(1, m)`,
3628 ///    - if `offsetc = R`: the `len` must be at least `max(1, n)`,
3629 ///
3630 /// The matrices are assumed to be stored in row-major order (the elements in
3631 /// each of the matrix rows are contiguous in memory).
3632 ///
3633 /// @note
3634 ///     This API does not support XERBLA. Instead, unlike the standard BLAS
3635 ///     functions, this one returns a dnnl_status_t value to allow error
3636 ///     handling.
3637 ///
3638 /// @warning
3639 ///     On some architectures saturation may happen during intermediate
3640 ///     computations, which would lead to unexpected results. For more
3641 ///     details, refer to @ref dev_guide_int8_computations.
3642 ///
3643 /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3644 ///     transposed, and 'T' or 't' means that A is transposed.
3645 /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3646 ///     transposed, and 'T' or 't' means that B is transposed.
3647 /// @param offsetc Flag specifying how offsets should be applied to matrix C:
3648 ///     - 'F' means that the same offset will be applied to each element of
3649 ///         the matrix C,
3650 ///     - 'C' means that individual offset will be applied to each element
3651 ///         within each column,
3652 ///     - 'R' means that individual offset will be applied to each element
3653 ///         within each row.
3654 /// @param M The M dimension.
3655 /// @param N The N dimension.
3656 /// @param K The K dimension.
3657 /// @param alpha The alpha parameter that is used to scale the product of
3658 ///     matrices A and B.
3659 /// @param A A pointer to the A matrix data.
3660 /// @param lda The leading dimension for the matrix A.
3661 /// @param ao The offset value for the matrix A.
3662 /// @param B A pointer to the B matrix data.
3663 /// @param ldb The leading dimension for the matrix B.
3664 /// @param bo The offset value for the matrix B.
3665 /// @param beta The beta parameter that is used to scale the matrix C.
3666 /// @param C A pointer to the C matrix data.
3667 /// @param ldc The leading dimension for the matrix C.
3668 /// @param co An array of offset values for the matrix C. The number of
3669 ///     elements in the array depends on the value of @p offsetc.
3670 /// @returns #dnnl_success/#dnnl::status::success on success and a status
3671 ///     describing the error otherwise.
3672 dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
3673         dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
3674         dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
3675         float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
3676 
3677 /// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
3678 /// signed matrix B, and 32-bit signed resulting matrix C.
3679 ///
3680 /// The operation is defined as:
3681 ///
3682 /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
3683 ///
3684 /// where
3685 ///  - `op( X ) = X` or `op( X ) = X**T`,
3686 ///  - `alpha` and `beta` are scalars, and
3687 ///  - `A`, `B`, and `C` are matrices:
3688 ///     - `op( A )` is an `MxK` matrix,
3689 ///     - `op( B )` is an `KxN` matrix,
3690 ///     - `C` is an `MxN` matrix.
3691 ///  - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
3692 ///  - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
3693 ///  - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
3694 ///    - if `offsetc = F`: the `len` must be at least `1`,
3695 ///    - if `offsetc = C`: the `len` must be at least `max(1, m)`,
3696 ///    - if `offsetc = R`: the `len` must be at least `max(1, n)`,
3697 ///
3698 /// The matrices are assumed to be stored in row-major order (the elements in
3699 /// each of the matrix rows are contiguous in memory).
3700 ///
3701 /// @note
3702 ///     This API does not support XERBLA. Instead, unlike the standard BLAS
3703 ///     functions, this one returns a dnnl_status_t value to allow error
3704 ///     handling.
3705 ///
3706 /// @warning
3707 ///     On some architectures saturation may happen during intermediate
3708 ///     computations, which would lead to unexpected results. For more
3709 ///     details, refer to @ref dev_guide_int8_computations.
3710 ///
3711 /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
3712 ///     transposed, and 'T' or 't' means that A is transposed.
3713 /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
3714 ///     transposed, and 'T' or 't' means that B is transposed.
3715 /// @param offsetc Flag specifying how offsets should be applied to matrix C:
3716 ///     - 'F' means that the same offset will be applied to each element of
3717 ///         the matrix C,
3718 ///     - 'C' means that individual offset will be applied to each element
3719 ///         within each column,
3720 ///     - 'R' means that individual offset will be applied to each element
3721 ///         within each row.
3722 /// @param M The M dimension.
3723 /// @param N The N dimension.
3724 /// @param K The K dimension.
3725 /// @param alpha The alpha parameter that is used to scale the product of
3726 ///     matrices A and B.
3727 /// @param A A pointer to the A matrix data.
3728 /// @param lda The leading dimension for the matrix A.
3729 /// @param ao The offset value for the matrix A.
3730 /// @param B A pointer to the B matrix data.
3731 /// @param ldb The leading dimension for the matrix B.
3732 /// @param bo The offset value for the matrix B.
3733 /// @param beta The beta parameter that is used to scale the matrix C.
3734 /// @param C A pointer to the C matrix data.
3735 /// @param ldc The leading dimension for the matrix C.
3736 /// @param co An array of offset values for the matrix C. The number of
3737 ///     elements in the array depends on the value of @p offsetc.
3738 /// @returns #dnnl_success/#dnnl::status::success on success and a status
3739 ///     describing the error otherwise.
3740 dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
3741         dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
3742         dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
3743         float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
3744 
3745 /// @} dnnl_api_blas
3746 
3747 /// @} dnnl_api
3748 
3749 #ifdef __cplusplus
3750 }
3751 #endif
3752 
3753 #endif /* ONEAPI_DNNL_DNNL_H */
3754