/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @example cpu_rnn_inference_f32.cpp
/// @copybrief cpu_rnn_inference_f32_cpp
/// > Annotated version: @ref cpu_rnn_inference_f32_cpp

/// @page cpu_rnn_inference_f32_cpp RNN f32 inference example
/// This C++ API example demonstrates how to build GNMT model inference.
///
/// > Example code: @ref cpu_rnn_inference_f32.cpp
///
/// For the encoder we use:
///  - one primitive for the bidirectional layer of the encoder
///  - one primitive for all remaining unidirectional layers in the encoder
/// For the decoder we use:
///  - one primitive for the first iteration
///  - one primitive for all subsequent iterations in the decoder. Note that
///    in this example, this primitive computes the states in place.
///  - the attention mechanism is implemented separately as there is no support
///    for the context vectors in oneDNN yet

#include <assert.h>

#include <cstring>
#include <iostream>
#include <math.h>
#include <numeric>
#include <string>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

using dim_t = dnnl::memory::dim;

const dim_t batch = 32;
const dim_t src_seq_length_max = 10;
const dim_t tgt_seq_length_max = 10;

const dim_t feature_size = 256;

const dim_t enc_bidir_n_layers = 1;
const dim_t enc_unidir_n_layers = 3;
const dim_t dec_n_layers = 4;

const int lstm_n_gates = 4;
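// oneDNN LSTM cells use four gates per cell, stored in (input, forget,
// candidate, output) order.
// The vectors below are global scratch buffers reused by the attention
// computation on every decoder step.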
std::vector<float> weighted_src_layer(batch * feature_size, 1.0f);
std::vector<float> alignment_model(
        src_seq_length_max * batch * feature_size, 1.0f);
std::vector<float> alignments(src_seq_length_max * batch, 1.0f);
std::vector<float> exp_sums(batch, 1.0f);

void compute_weighted_annotations(float *weighted_annotations,
        dim_t src_seq_length_max, dim_t batch, dim_t feature_size,
        float *weights_annot, float *annotations) {
    // annotations (aka enc_dst_layer) is (t, n, c)
    // weights_annot is (c, c)

    // annotation[i] = GEMM(weights_annot, enc_dst_layer[i]);
    dim_t num_weighted_annotations = src_seq_length_max * batch;
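    // dnnl_sgemm is a row-major SGEMM: C = alpha * op(A) * op(B) + beta * C.
    // Here M = t * n and N = K = c, so every (time, batch) row of the
    // annotations is multiplied by the same (c, c) weights matrix.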
    dnnl_sgemm('N', 'N', num_weighted_annotations, feature_size, feature_size,
            1.f, annotations, feature_size, weights_annot, feature_size, 0.f,
            weighted_annotations, feature_size);
}

void compute_attention(float *context_vectors, dim_t src_seq_length_max,
        dim_t batch, dim_t feature_size, float *weights_src_layer,
        float *dec_src_layer, float *annotations, float *weighted_annotations,
        float *weights_alignments) {
    // dec_src_layer (the previous decoder output) is an (n, c) matrix
    // weighted_annotations is (t, n, c)

    // weights_src_layer is (c, c)
    // weights_alignments is (c, 1)
    // alignment_model[i] is (n, c)
    // alignments[i] is (n, 1)
    // context_vectors is (n, 2c); the context is written to the second half

    // first we precompute the weighted_dec_src_layer
    dnnl_sgemm('N', 'N', batch, feature_size, feature_size, 1.f, dec_src_layer,
            feature_size, weights_src_layer, feature_size, 0.f,
            weighted_src_layer.data(), feature_size);

    // then we compute the alignment model
    float *alignment_model_ptr = alignment_model.data();

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < src_seq_length_max; i++) {
        for (dim_t j = 0; j < batch * feature_size; j++)
            alignment_model_ptr[i * batch * feature_size + j] = tanhf(
                    weighted_src_layer[j]
                    + weighted_annotations[i * batch * feature_size + j]);
    }

    // gemv with alignments weights. the resulting alignments are in alignments
    dim_t num_weighted_annotations = src_seq_length_max * batch;
    dnnl_sgemm('N', 'N', num_weighted_annotations, 1, feature_size, 1.f,
            alignment_model_ptr, feature_size, weights_alignments, 1, 0.f,
            alignments.data(), 1);

    // softmax on alignments. the resulting context weights are in alignments
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1)
    for (dim_t i = 0; i < batch; i++)
        exp_sums[i] = 0.0f;

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(1)
    for (dim_t j = 0; j < batch; j++) {
        for (dim_t i = 0; i < src_seq_length_max; i++) {
            alignments[i * batch + j] = expf(alignments[i * batch + j]);
            exp_sums[j] += alignments[i * batch + j];
        }
    }

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < src_seq_length_max; i++)
        for (dim_t j = 0; j < batch; j++)
            alignments[i * batch + j] /= exp_sums[j];
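    // Note: a numerically robust softmax would subtract the per-column
    // maximum before exponentiating; it is omitted here to keep the
    // example short.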

    // then we compute the context vectors
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < batch; i++)
        for (dim_t j = 0; j < feature_size; j++)
            context_vectors[i * (feature_size + feature_size) + feature_size
                    + j]
                    = 0.0f;

    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2)
    for (dim_t i = 0; i < batch; i++)
        for (dim_t j = 0; j < feature_size; j++)
            for (dim_t k = 0; k < src_seq_length_max; k++)
                context_vectors[i * (feature_size + feature_size) + feature_size
                        + j]
                        += alignments[k * batch + i]
                        * annotations[j + feature_size * (i + batch * k)];
}

void copy_context(
        float *src_iter, dim_t n_layers, dim_t batch, dim_t feature_size) {
    // we copy the context from the first layer to all other layers
    PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(3)
    for (dim_t k = 1; k < n_layers; k++)
        for (dim_t j = 0; j < batch; j++)
            for (dim_t i = 0; i < feature_size; i++)
                src_iter[(k * batch + j) * (feature_size + feature_size)
                        + feature_size + i]
                        = src_iter[j * (feature_size + feature_size)
                                + feature_size + i];
}

void simple_net() {
    ///
    /// Initialize a CPU engine and stream. The last parameter in the call represents
    /// the index of the engine.
    /// @snippet cpu_rnn_inference_f32.cpp Initialize engine and stream
    ///
    //[Initialize engine and stream]
    auto cpu_engine = engine(engine::kind::cpu, 0);
    stream s(cpu_engine);
    //[Initialize engine and stream]
    ///
    /// Declare encoder net and decoder net
    /// @snippet cpu_rnn_inference_f32.cpp declare net
    ///
    //[declare net]
    std::vector<primitive> encoder_net, decoder_net;
    std::vector<std::unordered_map<int, memory>> encoder_net_args,
            decoder_net_args;

    std::vector<float> net_src(batch * src_seq_length_max * feature_size, 1.0f);
    std::vector<float> net_dst(batch * tgt_seq_length_max * feature_size, 1.0f);
    //[declare net]
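    // Each primitive in a net is paired index-wise with an unordered_map of
    // execution arguments; the two vectors are walked together in the
    // execute() lambda below.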
    ///
    /// **Encoder**
    ///
    ///
    /// Initialize Encoder Memory
    /// @snippet cpu_rnn_inference_f32.cpp Initialize encoder memory
    ///
    //[Initialize encoder memory]
    memory::dims enc_bidir_src_layer_tz
            = {src_seq_length_max, batch, feature_size};
    memory::dims enc_bidir_weights_layer_tz
            = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
    memory::dims enc_bidir_weights_iter_tz
            = {enc_bidir_n_layers, 2, feature_size, lstm_n_gates, feature_size};
    memory::dims enc_bidir_bias_tz
            = {enc_bidir_n_layers, 2, lstm_n_gates, feature_size};
    memory::dims enc_bidir_dst_layer_tz
            = {src_seq_length_max, batch, 2 * feature_size};
    //[Initialize encoder memory]
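    // Weights dims follow the (layers, directions, input channels, gates,
    // output channels) convention matching format_tag::ldigo below; the "2"
    // is the number of directions in the bidirectional layer.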

    ///
    ///
    /// Encoder: 1 bidirectional layer and 3 unidirectional layers
    ///

    std::vector<float> user_enc_bidir_wei_layer(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_wei_iter(
            enc_bidir_n_layers * 2 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_bidir_bias(
            enc_bidir_n_layers * 2 * lstm_n_gates * feature_size, 1.0f);

    ///
    /// Create the memory for user data
    /// @snippet cpu_rnn_inference_f32.cpp data memory creation
    ///
    //[data memory creation]
    auto user_enc_bidir_src_layer_md = dnnl::memory::desc(
            {enc_bidir_src_layer_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::tnc);

    auto user_enc_bidir_wei_layer_md = dnnl::memory::desc(
            {enc_bidir_weights_layer_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);

    auto user_enc_bidir_wei_iter_md = dnnl::memory::desc(
            {enc_bidir_weights_iter_tz}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);

    auto user_enc_bidir_bias_md = dnnl::memory::desc({enc_bidir_bias_tz},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);

    auto user_enc_bidir_src_layer_memory = dnnl::memory(
            user_enc_bidir_src_layer_md, cpu_engine, net_src.data());
    auto user_enc_bidir_wei_layer_memory
            = dnnl::memory(user_enc_bidir_wei_layer_md, cpu_engine,
                    user_enc_bidir_wei_layer.data());
    auto user_enc_bidir_wei_iter_memory
            = dnnl::memory(user_enc_bidir_wei_iter_md, cpu_engine,
                    user_enc_bidir_wei_iter.data());
    auto user_enc_bidir_bias_memory = dnnl::memory(
            user_enc_bidir_bias_md, cpu_engine, user_enc_bidir_bias.data());

    //[data memory creation]
    ///
    /// Create memory descriptors for RNN data without a specified layout
    /// (format_tag::any), letting the primitive choose an optimized one
    /// @snippet cpu_rnn_inference_f32.cpp memory desc for RNN data
    ///
    //[memory desc for RNN data]
    auto enc_bidir_wei_layer_md = memory::desc({enc_bidir_weights_layer_tz},
            memory::data_type::f32, memory::format_tag::any);

    auto enc_bidir_wei_iter_md = memory::desc({enc_bidir_weights_iter_tz},
            memory::data_type::f32, memory::format_tag::any);

    auto enc_bidir_dst_layer_md = memory::desc({enc_bidir_dst_layer_tz},
            memory::data_type::f32, memory::format_tag::any);

    //[memory desc for RNN data]
    ///
    /// Create bidirectional RNN
    /// @snippet cpu_rnn_inference_f32.cpp create rnn
    ///
    //[create rnn]

    lstm_forward::desc bi_layer_desc(prop_kind::forward_inference,
            rnn_direction::bidirectional_concat, user_enc_bidir_src_layer_md,
            memory::desc(), memory::desc(), enc_bidir_wei_layer_md,
            enc_bidir_wei_iter_md, user_enc_bidir_bias_md,
            enc_bidir_dst_layer_md, memory::desc(), memory::desc());

    auto enc_bidir_prim_desc
            = dnnl::lstm_forward::primitive_desc(bi_layer_desc, cpu_engine);
    //[create rnn]

    ///
    /// Create memory for input data and use reorders to reorder user data
    /// to internal representation
    /// @snippet cpu_rnn_inference_f32.cpp reorder input data
    ///
    //[reorder input data]
    auto enc_bidir_wei_layer_memory
            = memory(enc_bidir_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_bidir_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_layer_memory, enc_bidir_wei_layer_memory);
    reorder(enc_bidir_wei_layer_reorder_pd)
            .execute(s, user_enc_bidir_wei_layer_memory,
                    enc_bidir_wei_layer_memory);
    //[reorder input data]
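    // In production code this reorder is typically guarded, e.g. executed
    // only if enc_bidir_prim_desc.weights_layer_desc() differs from
    // user_enc_bidir_wei_layer_md, so data is copied only when the primitive
    // picked a non-user layout.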

    auto enc_bidir_wei_iter_memory
            = memory(enc_bidir_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_bidir_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_bidir_wei_iter_memory, enc_bidir_wei_iter_memory);
    reorder(enc_bidir_wei_iter_reorder_pd)
            .execute(s, user_enc_bidir_wei_iter_memory,
                    enc_bidir_wei_iter_memory);

    auto enc_bidir_dst_layer_memory
            = dnnl::memory(enc_bidir_prim_desc.dst_layer_desc(), cpu_engine);

    ///
    /// Encoder : add the bidirectional rnn primitive with related arguments into encoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push bi rnn to encoder net
    ///
    //[push bi rnn to encoder net]
    encoder_net.push_back(lstm_forward(enc_bidir_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, user_enc_bidir_src_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_bidir_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_bidir_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_bidir_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_bidir_dst_layer_memory}});
    //[push bi rnn to encoder net]
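    // No DNNL_ARG_SRC_ITER / DNNL_ARG_SRC_ITER_C are passed, so the initial
    // hidden and cell states are implicitly zero.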

    ///
    /// Encoder: unidirectional layers
    ///
    ///
    /// The first unidirectional layer maps the 2 * feature_size output of the
    /// bidirectional layer to a feature_size output
    /// @snippet cpu_rnn_inference_f32.cpp first uni layer
    ///
    //[first uni layer]
    std::vector<float> user_enc_uni_first_wei_layer(
            1 * 1 * 2 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_wei_iter(
            1 * 1 * feature_size * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_enc_uni_first_bias(
            1 * 1 * lstm_n_gates * feature_size, 1.0f);
    //[first uni layer]
    memory::dims user_enc_uni_first_wei_layer_dims
            = {1, 1, 2 * feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_first_wei_iter_dims
            = {1, 1, feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_first_bias_dims
            = {1, 1, lstm_n_gates, feature_size};
    memory::dims enc_uni_first_dst_layer_dims
            = {src_seq_length_max, batch, feature_size};
    auto user_enc_uni_first_wei_layer_md = dnnl::memory::desc(
            {user_enc_uni_first_wei_layer_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_first_wei_iter_md = dnnl::memory::desc(
            {user_enc_uni_first_wei_iter_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_first_bias_md = dnnl::memory::desc(
            {user_enc_uni_first_bias_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldgo);
    auto user_enc_uni_first_wei_layer_memory
            = dnnl::memory(user_enc_uni_first_wei_layer_md, cpu_engine,
                    user_enc_uni_first_wei_layer.data());
    auto user_enc_uni_first_wei_iter_memory
            = dnnl::memory(user_enc_uni_first_wei_iter_md, cpu_engine,
                    user_enc_uni_first_wei_iter.data());
    auto user_enc_uni_first_bias_memory
            = dnnl::memory(user_enc_uni_first_bias_md, cpu_engine,
                    user_enc_uni_first_bias.data());

    auto enc_uni_first_wei_layer_md
            = memory::desc({user_enc_uni_first_wei_layer_dims},
                    memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_first_wei_iter_md
            = memory::desc({user_enc_uni_first_wei_iter_dims},
                    memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_first_dst_layer_md
            = memory::desc({enc_uni_first_dst_layer_dims},
                    memory::data_type::f32, memory::format_tag::any);

    // TODO: add support for residual connections
    // should this be a residual flag in op_desc or a field set manually?
    // an integer could specify at which layer the residuals start
    ///
    /// Encoder : Create unidirectional RNN for the first layer
    /// @snippet cpu_rnn_inference_f32.cpp create uni first
    ///
    //[create uni first]
    lstm_forward::desc enc_uni_first_layer_desc(prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right, enc_bidir_dst_layer_md,
            memory::desc(), memory::desc(), enc_uni_first_wei_layer_md,
            enc_uni_first_wei_iter_md, user_enc_uni_first_bias_md,
            enc_uni_first_dst_layer_md, memory::desc(), memory::desc());

    auto enc_uni_first_prim_desc = dnnl::lstm_forward::primitive_desc(
            enc_uni_first_layer_desc, cpu_engine);

    //[create uni first]
    auto enc_uni_first_wei_layer_memory
            = memory(enc_uni_first_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_uni_first_wei_layer_reorder_pd
            = reorder::primitive_desc(user_enc_uni_first_wei_layer_memory,
                    enc_uni_first_wei_layer_memory);
    reorder(enc_uni_first_wei_layer_reorder_pd)
            .execute(s, user_enc_uni_first_wei_layer_memory,
                    enc_uni_first_wei_layer_memory);

    auto enc_uni_first_wei_iter_memory
            = memory(enc_uni_first_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_uni_first_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_first_wei_iter_memory, enc_uni_first_wei_iter_memory);
    reorder(enc_uni_first_wei_iter_reorder_pd)
            .execute(s, user_enc_uni_first_wei_iter_memory,
                    enc_uni_first_wei_iter_memory);

    auto enc_uni_first_dst_layer_memory = dnnl::memory(
            enc_uni_first_prim_desc.dst_layer_desc(), cpu_engine);

    /// Encoder : add the first unidirectional rnn primitive with related
    /// arguments into encoder_net
    ///
    /// @snippet cpu_rnn_inference_f32.cpp push first uni rnn to encoder net
    ///
    //[push first uni rnn to encoder net]
    // TODO: add a reorder when they become available
    encoder_net.push_back(lstm_forward(enc_uni_first_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, enc_bidir_dst_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_uni_first_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_uni_first_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_uni_first_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_uni_first_dst_layer_memory}});
    //[push first uni rnn to encoder net]

    ///
    /// Encoder : Remaining unidirectional layers
    /// @snippet cpu_rnn_inference_f32.cpp remaining uni layers
    ///
    //[remaining uni layers]
    std::vector<float> user_enc_uni_wei_layer((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_uni_wei_iter((enc_unidir_n_layers - 1) * 1
                    * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_enc_uni_bias(
            (enc_unidir_n_layers - 1) * 1 * lstm_n_gates * feature_size, 1.0f);
    //[remaining uni layers]
    memory::dims user_enc_uni_wei_layer_dims = {(enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_wei_iter_dims = {(enc_unidir_n_layers - 1), 1,
            feature_size, lstm_n_gates, feature_size};
    memory::dims user_enc_uni_bias_dims
            = {(enc_unidir_n_layers - 1), 1, lstm_n_gates, feature_size};
    memory::dims enc_dst_layer_dims = {src_seq_length_max, batch, feature_size};
    auto user_enc_uni_wei_layer_md = dnnl::memory::desc(
            {user_enc_uni_wei_layer_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_wei_iter_md = dnnl::memory::desc(
            {user_enc_uni_wei_iter_dims}, dnnl::memory::data_type::f32,
            dnnl::memory::format_tag::ldigo);
    auto user_enc_uni_bias_md = dnnl::memory::desc({user_enc_uni_bias_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);
    auto user_enc_uni_wei_layer_memory = dnnl::memory(user_enc_uni_wei_layer_md,
            cpu_engine, user_enc_uni_wei_layer.data());
    auto user_enc_uni_wei_iter_memory = dnnl::memory(
            user_enc_uni_wei_iter_md, cpu_engine, user_enc_uni_wei_iter.data());
    auto user_enc_uni_bias_memory = dnnl::memory(
            user_enc_uni_bias_md, cpu_engine, user_enc_uni_bias.data());

    auto enc_uni_wei_layer_md = memory::desc({user_enc_uni_wei_layer_dims},
            memory::data_type::f32, memory::format_tag::any);
    auto enc_uni_wei_iter_md = memory::desc({user_enc_uni_wei_iter_dims},
            memory::data_type::f32, memory::format_tag::any);
    auto enc_dst_layer_md = memory::desc({enc_dst_layer_dims},
            memory::data_type::f32, memory::format_tag::any);

    // TODO: add support for residual connections
    // should this be a residual flag in op_desc or a field set manually?
    // an integer could specify at which layer the residuals start
    ///
    /// Encoder : Create the unidirectional RNN for the remaining layers
    /// @snippet cpu_rnn_inference_f32.cpp create uni rnn
    ///
    //[create uni rnn]
    lstm_forward::desc enc_uni_layer_desc(prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right,
            enc_uni_first_dst_layer_md, memory::desc(), memory::desc(),
            enc_uni_wei_layer_md, enc_uni_wei_iter_md, user_enc_uni_bias_md,
            enc_dst_layer_md, memory::desc(), memory::desc());
    auto enc_uni_prim_desc = dnnl::lstm_forward::primitive_desc(
            enc_uni_layer_desc, cpu_engine);
    //[create uni rnn]

    auto enc_uni_wei_layer_memory
            = memory(enc_uni_prim_desc.weights_layer_desc(), cpu_engine);
    auto enc_uni_wei_layer_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory);
    reorder(enc_uni_wei_layer_reorder_pd)
            .execute(
                    s, user_enc_uni_wei_layer_memory, enc_uni_wei_layer_memory);

    auto enc_uni_wei_iter_memory
            = memory(enc_uni_prim_desc.weights_iter_desc(), cpu_engine);
    auto enc_uni_wei_iter_reorder_pd = reorder::primitive_desc(
            user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory);
    reorder(enc_uni_wei_iter_reorder_pd)
            .execute(s, user_enc_uni_wei_iter_memory, enc_uni_wei_iter_memory);

    auto enc_dst_layer_memory
            = dnnl::memory(enc_uni_prim_desc.dst_layer_desc(), cpu_engine);

    // TODO: add a reorder when they become available
    ///
    /// Encoder : add the unidirectional rnn primitive with related arguments into encoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push uni rnn to encoder net
    ///
    //[push uni rnn to encoder net]
    encoder_net.push_back(lstm_forward(enc_uni_prim_desc));
    encoder_net_args.push_back(
            {{DNNL_ARG_SRC_LAYER, enc_uni_first_dst_layer_memory},
                    {DNNL_ARG_WEIGHTS_LAYER, enc_uni_wei_layer_memory},
                    {DNNL_ARG_WEIGHTS_ITER, enc_uni_wei_iter_memory},
                    {DNNL_ARG_BIAS, user_enc_uni_bias_memory},
                    {DNNL_ARG_DST_LAYER, enc_dst_layer_memory}});
    //[push uni rnn to encoder net]
    ///
    /// **Decoder with attention mechanism**
    ///
    ///
    /// Decoder : declare memory dimensions
    /// @snippet cpu_rnn_inference_f32.cpp dec mem dim
    ///
    //[dec mem dim]
    std::vector<float> user_dec_wei_layer(
            dec_n_layers * 1 * feature_size * lstm_n_gates * feature_size,
            1.0f);
    std::vector<float> user_dec_wei_iter(dec_n_layers * 1
                    * (feature_size + feature_size) * lstm_n_gates
                    * feature_size,
            1.0f);
    std::vector<float> user_dec_bias(
            dec_n_layers * 1 * lstm_n_gates * feature_size, 1.0f);
    std::vector<float> user_dec_dst(
            tgt_seq_length_max * batch * feature_size, 1.0f);
    std::vector<float> user_weights_attention_src_layer(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_annotation(
            feature_size * feature_size, 1.0f);
    std::vector<float> user_weights_alignments(feature_size, 1.0f);

    memory::dims user_dec_wei_layer_dims
            = {dec_n_layers, 1, feature_size, lstm_n_gates, feature_size};
    memory::dims user_dec_wei_iter_dims = {dec_n_layers, 1,
            feature_size + feature_size, lstm_n_gates, feature_size};
    memory::dims user_dec_bias_dims
            = {dec_n_layers, 1, lstm_n_gates, feature_size};

    memory::dims dec_src_layer_dims = {1, batch, feature_size};
    memory::dims dec_dst_layer_dims = {1, batch, feature_size};
    memory::dims dec_dst_iter_c_dims = {dec_n_layers, 1, batch, feature_size};
    //[dec mem dim]

    /// We will use the same memory for dec_src_iter and dec_dst_iter.
    /// However, dec_src_iter carries the attention context vector while
    /// dec_dst_iter does not.
    /// To resolve this, we create one memory that holds both the hidden
    /// states and the context vectors, and make dst_iter a sub-memory of it
    /// so that each iteration updates the hidden states without touching
    /// the context half.
    /// @snippet cpu_rnn_inference_f32.cpp noctx mem dim
    //[noctx mem dim]
    memory::dims dec_dst_iter_dims
            = {dec_n_layers, 1, batch, feature_size + feature_size};
    memory::dims dec_dst_iter_noctx_dims
            = {dec_n_layers, 1, batch, feature_size};
    //[noctx mem dim]

    ///
    /// Decoder : create memory description
    /// @snippet cpu_rnn_inference_f32.cpp dec mem desc
    ///
    //[dec mem desc]
    auto user_dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
    auto user_dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldigo);
    auto user_dec_bias_md = dnnl::memory::desc({user_dec_bias_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldgo);
    auto dec_dst_layer_md = dnnl::memory::desc({dec_dst_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
    auto dec_src_layer_md = dnnl::memory::desc({dec_src_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::tnc);
    auto dec_dst_iter_md = dnnl::memory::desc({dec_dst_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);
    auto dec_dst_iter_c_md = dnnl::memory::desc({dec_dst_iter_c_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::ldnc);
    //[dec mem desc]
    ///
    /// Decoder : Create memory
    /// @snippet cpu_rnn_inference_f32.cpp create dec memory
    ///
    //[create dec memory]
    auto user_dec_wei_layer_memory = dnnl::memory(
            user_dec_wei_layer_md, cpu_engine, user_dec_wei_layer.data());
    auto user_dec_wei_iter_memory = dnnl::memory(
            user_dec_wei_iter_md, cpu_engine, user_dec_wei_iter.data());
    auto user_dec_bias_memory
            = dnnl::memory(user_dec_bias_md, cpu_engine, user_dec_bias.data());
    auto user_dec_dst_layer_memory
            = dnnl::memory(dec_dst_layer_md, cpu_engine, user_dec_dst.data());
    auto dec_src_layer_memory = dnnl::memory(dec_src_layer_md, cpu_engine);
    auto dec_dst_iter_c_memory = dnnl::memory(dec_dst_iter_c_md, cpu_engine);
    //[create dec memory]

    auto dec_wei_layer_md = dnnl::memory::desc({user_dec_wei_layer_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);
    auto dec_wei_iter_md = dnnl::memory::desc({user_dec_wei_iter_dims},
            dnnl::memory::data_type::f32, dnnl::memory::format_tag::any);

    // As mentioned above, we create a view without context out of the
    // memory with context.
    ///
    /// Decoder : As mentioned above, we create a view without context out of the memory with context.
    /// @snippet cpu_rnn_inference_f32.cpp create noctx mem
    ///
    //[create noctx mem]
    auto dec_dst_iter_memory = dnnl::memory(dec_dst_iter_md, cpu_engine);
    auto dec_dst_iter_noctx_md = dec_dst_iter_md.submemory_desc(
            dec_dst_iter_noctx_dims, {0, 0, 0, 0});
    //[create noctx mem]
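    // The sub-memory keeps the strides of its (..., 2 * feature_size) parent,
    // so the primitive writes the hidden states into the first feature_size
    // channels and leaves the attention context in the second half intact.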

    // TODO: add support for residual connections
    // should this be a residual flag in op_desc or a field set manually?
    // an integer could specify at which layer the residuals start
    ///
    /// Decoder : Create RNN decoder cell
    /// @snippet cpu_rnn_inference_f32.cpp create dec rnn
    ///
    //[create dec rnn]
    lstm_forward::desc dec_ctx_desc(prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right, dec_src_layer_md,
            dec_dst_iter_md, dec_dst_iter_c_md, dec_wei_layer_md,
            dec_wei_iter_md, user_dec_bias_md, dec_dst_layer_md,
            dec_dst_iter_noctx_md, dec_dst_iter_c_md);
    auto dec_ctx_prim_desc
            = dnnl::lstm_forward::primitive_desc(dec_ctx_desc, cpu_engine);
    //[create dec rnn]

    ///
    /// Decoder : reorder weight memory
    /// @snippet cpu_rnn_inference_f32.cpp reorder weight memory
    ///
    //[reorder weight memory]
    auto dec_wei_layer_memory
            = memory(dec_ctx_prim_desc.weights_layer_desc(), cpu_engine);
    auto dec_wei_layer_reorder_pd = reorder::primitive_desc(
            user_dec_wei_layer_memory, dec_wei_layer_memory);
    reorder(dec_wei_layer_reorder_pd)
            .execute(s, user_dec_wei_layer_memory, dec_wei_layer_memory);

    auto dec_wei_iter_memory
            = memory(dec_ctx_prim_desc.weights_iter_desc(), cpu_engine);
    auto dec_wei_iter_reorder_pd = reorder::primitive_desc(
            user_dec_wei_iter_memory, dec_wei_iter_memory);
    reorder(dec_wei_iter_reorder_pd)
            .execute(s, user_dec_wei_iter_memory, dec_wei_iter_memory);
    //[reorder weight memory]

    ///
    /// Decoder : add the rnn primitive with related arguments into decoder_net
    /// @snippet cpu_rnn_inference_f32.cpp push rnn to decoder net
    ///
    //[push rnn to decoder net]
    // TODO: add a reorder when they become available
    decoder_net.push_back(lstm_forward(dec_ctx_prim_desc));
    decoder_net_args.push_back({{DNNL_ARG_SRC_LAYER, dec_src_layer_memory},
            {DNNL_ARG_SRC_ITER, dec_dst_iter_memory},
            {DNNL_ARG_SRC_ITER_C, dec_dst_iter_c_memory},
            {DNNL_ARG_WEIGHTS_LAYER, dec_wei_layer_memory},
            {DNNL_ARG_WEIGHTS_ITER, dec_wei_iter_memory},
            {DNNL_ARG_BIAS, user_dec_bias_memory},
            {DNNL_ARG_DST_LAYER, user_dec_dst_layer_memory},
            {DNNL_ARG_DST_ITER, dec_dst_iter_memory},
            {DNNL_ARG_DST_ITER_C, dec_dst_iter_c_memory}});
    //[push rnn to decoder net]
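    // dec_dst_iter_memory is passed as both SRC_ITER and DST_ITER, so each
    // decoder step updates the recurrent state in place, as noted in the
    // introduction.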
    // allocate a temporary buffer for the attention mechanism
    std::vector<float> weighted_annotations(
            src_seq_length_max * batch * feature_size, 1.0f);

    ///
    /// **Execution**
    ///
    auto execute = [&]() {
        assert(encoder_net.size() == encoder_net_args.size()
                && "something is missing");
        ///
        /// run encoder (1 stream)
        /// @snippet cpu_rnn_inference_f32.cpp run enc
        ///
        //[run enc]
        for (size_t p = 0; p < encoder_net.size(); ++p)
            encoder_net.at(p).execute(s, encoder_net_args.at(p));
        //[run enc]

        ///
        /// we compute the weighted annotations once before the decoder
        /// @snippet cpu_rnn_inference_f32.cpp weight ano
        ///
        //[weight ano]
        compute_weighted_annotations(weighted_annotations.data(),
                src_seq_length_max, batch, feature_size,
                user_weights_annotation.data(),
                (float *)enc_dst_layer_memory.get_data_handle());
        //[weight ano]

        ///
        /// We initialize src_layer to the embedding of the end-of-sequence
        /// character, which is assumed to be 0 here
        /// @snippet cpu_rnn_inference_f32.cpp init src_layer
        ///
        //[init src_layer]
        memset(dec_src_layer_memory.get_data_handle(), 0,
                dec_src_layer_memory.get_desc().get_size());
        //[init src_layer]
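        // get_desc().get_size() returns the buffer size in bytes, so the
        // memset clears the whole (1, batch, feature_size) f32 buffer.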
        ///
        /// From now on, src points to the output of the last iteration
        ///
        for (dim_t i = 0; i < tgt_seq_length_max; i++) {
            float *src_att_layer_handle
                    = (float *)dec_src_layer_memory.get_data_handle();
            float *src_att_iter_handle
                    = (float *)dec_dst_iter_memory.get_data_handle();

            ///
            /// Compute attention context vector into the first layer src_iter
            /// @snippet cpu_rnn_inference_f32.cpp att ctx
            ///
            //[att ctx]
            compute_attention(src_att_iter_handle, src_seq_length_max, batch,
                    feature_size, user_weights_attention_src_layer.data(),
                    src_att_layer_handle,
                    (float *)enc_bidir_dst_layer_memory.get_data_handle(),
                    weighted_annotations.data(),
                    user_weights_alignments.data());
            //[att ctx]

            ///
            /// copy the context vectors to all layers of src_iter
            /// @snippet cpu_rnn_inference_f32.cpp cp ctx
            ///
            //[cp ctx]
            copy_context(
                    src_att_iter_handle, dec_n_layers, batch, feature_size);
            //[cp ctx]

            assert(decoder_net.size() == decoder_net_args.size()
                    && "something is missing");
            ///
            /// run the decoder iteration
            /// @snippet cpu_rnn_inference_f32.cpp run dec iter
            ///
            //[run dec iter]
            for (size_t p = 0; p < decoder_net.size(); ++p)
                decoder_net.at(p).execute(s, decoder_net_args.at(p));
            //[run dec iter]

            ///
            /// Move the handle on the src/dst layer to the next iteration
            /// @snippet cpu_rnn_inference_f32.cpp set handle
            ///
            //[set handle]
            auto dst_layer_handle
                    = (float *)user_dec_dst_layer_memory.get_data_handle();
            dec_src_layer_memory.set_data_handle(dst_layer_handle);
            user_dec_dst_layer_memory.set_data_handle(
                    dst_layer_handle + batch * feature_size);
            //[set handle]
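            // The decoder output of step i becomes the input of step i + 1,
            // and the user destination pointer advances one
            // (batch * feature_size) slice per step.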
        }
    };
    /// @page cpu_rnn_inference_f32_cpp
    ///
    std::cout << "Parameters:" << std::endl
              << " batch = " << batch << std::endl
              << " feature size = " << feature_size << std::endl
              << " maximum source sequence length = " << src_seq_length_max
              << std::endl
              << " maximum target sequence length = " << tgt_seq_length_max
              << std::endl
              << " number of layers of the bidirectional encoder = "
              << enc_bidir_n_layers << std::endl
              << " number of layers of the unidirectional encoder = "
              << enc_unidir_n_layers << std::endl
              << " number of layers of the decoder = " << dec_n_layers
              << std::endl;

    execute();
    s.wait();
}

int main(int argc, char **argv) {
    return handle_example_errors({engine::kind::cpu}, simple_net);
}