/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <stdlib.h>

#include <cmath>

#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"

#include "rnn/cells.hpp"

namespace rnn {

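// Allocates a single forward workspace buffer and carves it into the
// per-tensor views used by the reference implementation. The layer and
// iteration dimensions are padded by 2 so that the initial states copied in
// by copy_init_fwd and the final results read by copy_res_fwd share the same
// indexing scheme as the computed cells. Views that a given algorithm does
// not use (e.g. ws_src_iter_c for non-LSTM, ws_ht without projection) get no
// dedicated space in the buffer.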
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht) {
    bool is_lstm = prb.alg == VANILLA_LSTM;
    bool is_lstmp = prb.is_lstm_projection();

    ws_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_gates = AOC<float>(nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb,
            prb.n_gates(), prb.dhc);
    ws_ht = AOC<float>(
            nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb, prb.wc);

    int64_t size = ws_src_layer.nelems() + is_lstm * ws_src_iter_c.nelems()
            + ws_gates.nelems() + is_lstmp * ws_ht.nelems();
    ws_fwd_buffer.resize(size);

    float *ptr = ws_fwd_buffer.data();
    ws_src_layer.set_base_ptr(ptr);
    ws_src_iter.set_base_ptr(ptr);

    ptr += ws_src_iter.nelems();
    ws_src_iter_c.set_base_ptr(ptr);

    ptr += is_lstm * ws_src_iter_c.nelems();
    ws_gates.set_base_ptr(ptr);

    ptr += is_lstmp * ws_gates.nelems();
    ws_ht.set_base_ptr(ptr);
}

/******************************************************************************/
/******************************* Copy Routines ********************************/
/******************************************************************************/
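// For int8 LSTM with projection: precomputes, per output channel of the
// projection matrix, the sum of its weights. The LSTM cell uses these sums
// together with the data shift to compensate for the asymmetric quantization
// of the projection input.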
void prepare_projection_compensation(const prb_t &prb,
        float *weights_projection_compensation_,
        const float *weights_projection_) {
    AOC<float> weights_projection_compensation(weights_projection_compensation_,
            prb.n_layer, prb.n_dir(), prb.dic);
    AOC<const float> weights_projection(
            weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc, prb.dic);
    for (int layer = 0; layer < prb.n_layer; ++layer)
        for (int dir = 0; dir < prb.n_dir(); ++dir)
            for (int dic = 0; dic < prb.dic; ++dic) {
                float weights_compensation = 0;
                for (int dhc = 0; dhc < prb.dhc; ++dhc)
                    weights_compensation
                            += weights_projection(layer, dir, dhc, dic);
                weights_projection_compensation(layer, dir, dic)
                        = weights_compensation;
            }
}

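// For int8 runs: folds the quantization compensation term into the bias.
// Because the quantized sources carry a `data_shift` offset, each GEMM
// accumulator picks up an extra contribution proportional to the sum of its
// weights; subtracting the per-channel weight sums scaled by
// data_shift / (data_scale * wei_scale) from the bias cancels that term.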
void prepare_bias(const prb_t &prb, float *bias_with_compensation_,
        const float *bias_, const float *weights_layer_,
        const float *weights_iter_) {
    AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
            prb.slc, prb.n_gates(), prb.dhc);
    AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
            prb.sic, prb.n_gates(), prb.dhc);

    AOC<const float> bias(
            bias_, prb.n_layer, prb.n_dir(), prb.n_gates(), prb.dhc);
    AOC<float> bias_with_compensation(bias_with_compensation_, prb.n_layer,
            prb.n_dir(), prb.n_gates(), prb.dhc);

    for (int layer = 0; layer < prb.n_layer; ++layer)
        for (int dir = 0; dir < prb.n_dir(); ++dir)
            for (int gate = 0; gate < prb.n_gates(); ++gate)
                for (int dhc = 0; dhc < prb.dhc; ++dhc) {
                    float weights_compensation = 0;
                    for (int sic = 0; sic < prb.sic; ++sic)
                        weights_compensation
                                += weights_iter(layer, dir, sic, gate, dhc);
                    for (int slc = 0; slc < prb.slc; ++slc)
                        weights_compensation
                                += weights_layer(layer, dir, slc, gate, dhc);

                    float scale = prb.data_scale
                            * prb.get_wei_scale(gate * prb.dhc + dhc);
                    bias_with_compensation(layer, dir, gate, dhc)
                            = bias(layer, dir, gate, dhc)
                            - weights_compensation * prb.data_shift / scale;
                }
}

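// Copies the user-provided src_layer, src_iter and src_iter_c tensors into
// the padded border cells of the workspace (layer 0 or n_layer + 1, iteration
// 0 or n_iter + 1, depending on the traversal direction), quantizing them on
// the fly for int8 problems.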
void copy_init_fwd(const prb_t &prb, const AOC<float> &ws_src_layer,
        const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
        const float *src_layer_, const float *src_iter_,
        const float *src_iter_c_, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val) {
    AOC<const float> src_layer(src_layer_, prb.n_iter, prb.mb * prb.slc);
    AOC<const float> src_iter(
            src_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.sic);
    AOC<const float> src_iter_c(
            src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc);

    int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1;
    int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1;

    // Copy src_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        copy(prb.mb, prb.slc, prb.slc, prb.wc, &src_layer(it, 0),
                &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0));
        if (prb.is_int8())
            data_q10n(prb.mb, prb.slc, prb.wc,
                    &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0),
                    prb.data_scale, prb.data_shift);
    }

    // Copy src_iter (and src_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        copy(prb.mb, prb.sic, prb.sic, prb.wc, &src_iter(lay, dir_val, 0),
                &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0));
        if (prb.is_int8())
            data_q10n(prb.mb, prb.sic, prb.wc,
                    &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0),
                    prb.data_scale, prb.data_shift);

        if (prb.alg == VANILLA_LSTM)
            copy(prb.mb, prb.dhc, prb.dhc, prb.wc, &src_iter_c(lay, dir_val, 0),
                    &ws_src_iter_c(lay + 1, dir_val, it_dest, 0, 0));
    }
}

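// Copies the last computed cells of the workspace back into dst_layer,
// dst_iter and dst_iter_c, dequantizing where the destination data type
// requires it. For bidirectional problems, `action` controls whether the
// second direction is summed into or concatenated after the first one.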
void copy_res_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_,
        float *dst_iter_c_, const AOC<const float> &ws_src_layer,
        const AOC<const float> &ws_src_iter,
        const AOC<const float> &ws_src_iter_c, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val, rnn_action_t action) {
    AOC<float> dst_iter(dst_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.dic);
    AOC<float> dst_iter_c(
            dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc);
    AOC<float> dst_layer(dst_layer_, prb.n_iter, prb.mb, prb.dlc(PRIMITIVE));
    const bool is_layer_deq = (prb.is_u8() && prb.cfg[DST_LAYER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_LAYER].dt != dnnl_s8);
    const bool is_iter_deq = (prb.is_u8() && prb.cfg[DST_ITER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_ITER].dt != dnnl_s8);
    // Copy dst_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        for (int64_t nb = 0; nb < prb.mb; nb++) {
            auto from = &ws_src_layer(prb.n_layer, dir_val, it + 1, nb, 0);
            auto to = &dst_layer(
                    it, nb, action == action_concat ? prb.dlc(CELL) : 0);
            copy(1, prb.dlc(CELL), prb.wc, prb.dlc(PRIMITIVE), from, to, action,
                    prb.is_int8());

            if (is_layer_deq) {
                float data_shift = prb.data_shift;
                bool do_deq10n = true;

                if (prb.direction == dnnl_bidirectional_sum) {
                    // In `bidir_sum` case, we need to dequantize data only
                    // after the final summation. Also, since we sum two shifted
                    // tensors, we need to enlarge the shift by 2x.
                    do_deq10n = action == action_sum;
                    data_shift *= 2;
                }

                if (do_deq10n)
                    data_deq10n(1, prb.dlc(CELL), prb.dlc(PRIMITIVE), to,
                            prb.data_scale, data_shift);
            }
        }
    }

    int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1;

    // Copy dst_iter (and dst_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        if (prb.alg == VANILLA_LSTM) {
            copy(prb.mb, prb.dhc, prb.wc, prb.dhc,
                    &ws_src_iter_c(lay + 1, dir_val, it_source, 0, 0),
                    &dst_iter_c(lay, dir_val, 0, 0));
        }

        copy(prb.mb, prb.dic, prb.wc, prb.dic,
                &ws_src_iter(lay + 1, dir_val, it_source, 0, 0),
                &dst_iter(lay, dir_val, 0, 0));
        if (is_iter_deq)
            data_deq10n(prb.mb, prb.dic, prb.dic, &dst_iter(lay, dir_val, 0, 0),
                    prb.data_scale, prb.data_shift);
    }
}

/******************************************************************************/
/*************************** Computation Routines *****************************/
/******************************************************************************/

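// Dispatches a single cell computation to the algorithm-specific routine
// (vanilla RNN, vanilla/LBR GRU, or LSTM). For all algorithms except LSTM the
// hidden state is the only recurrent state, so dst_layer and dst_iter are
// expected to alias each other.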
void rnn_cell_fwd(const prb_t &prb, float *dst_layer, float *dst_iter,
        float *dst_iter_c, float *gates, float *ht, const float *weights_layer,
        const float *weights_iter, const float *weights_peephole,
        const float *weights_projection,
        const float *weights_projection_compensation, const float *bias,
        const float *src_layer, const float *src_iter, const float *src_iter_c,
        float *cell_scratchpad_) {
    if (prb.alg != VANILLA_LSTM) assert(dst_layer == dst_iter);

    switch (prb.alg) {
        case VANILLA_GRU:
            gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
                    src_layer, src_iter);
            break;
        case LBR_GRU:
            lbr_gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter,
                    bias, src_layer, src_iter, cell_scratchpad_);
            break;
        case VANILLA_LSTM:
            lstm_fwd(prb, dst_layer, dst_iter, dst_iter_c, gates, ht,
                    weights_layer, weights_iter, weights_peephole,
                    weights_projection, weights_projection_compensation, bias,
                    src_layer, src_iter, src_iter_c);
            break;
        case VANILLA_RNN:
            rnn_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
                    src_layer, src_iter);
            break;
        default: break;
    }
}

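// Reference forward pass over the whole problem: prepares int8 compensation
// data if needed, then, for each direction, seeds the workspace with the
// initial states, walks the (layer, iteration) grid cell by cell, and copies
// the results out. Per-direction results are combined (copy, sum, or concat)
// according to prb.direction.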
void rnn_linear_fwd(const prb_t &prb, const float *src_layer_,
        const float *src_iter_, const float *src_iter_c_,
        const float *weights_layer_, const float *weights_iter_,
        const float *weights_peephole_, const float *weights_projection_,
        const float *bias_, float *dst_layer_, float *dst_iter_,
        float *dst_iter_c_, const AOC<float> &ws_src_layer,
        const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
        const AOC<float> &ws_gates, const AOC<float> &ws_ht) {
    bool is_lbr = prb.alg == LBR_GRU;

    float *bias_with_compensation = nullptr;
    float *weights_projection_compensation_ = nullptr;
    if (prb.is_int8()) {
        bias_with_compensation = new float[prb.n_layer * prb.n_dir()
                * (prb.n_gates() + is_lbr) * prb.dhc];
        prepare_bias(prb, bias_with_compensation, bias_, weights_layer_,
                weights_iter_);
        bias_ = bias_with_compensation;
        if (prb.is_lstm_projection()) {
            weights_projection_compensation_
                    = new float[prb.n_layer * prb.n_dir() * prb.dic];
            prepare_projection_compensation(
                    prb, weights_projection_compensation_, weights_projection_);
        }
    }

    AOC<const float> weights_peephole(
            weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc);
    AOC<const float> weights_projection(
            weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic);
    AOC<const float> weights_projection_compensation(
            weights_projection_compensation_, prb.n_layer, prb.n_dir(),
            prb.dic);
    AOC<const float> bias(bias_, prb.n_layer, prb.n_dir(),
            (prb.n_gates() + is_lbr) * prb.dhc);
    AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
            prb.n_gates() * prb.dhc, prb.slc);
    AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
            prb.n_gates() * prb.dhc, prb.sic);

    int64_t cell_scratchpad_size = is_lbr * prb.mb * prb.n_gates() * prb.dhc;
    float *cell_scratchpad_ = new float[cell_scratchpad_size];
    for (int i = 0; i < cell_scratchpad_size; i++) {
        cell_scratchpad_[i] = NAN;
    }

    auto process_direction = [&](rnn_iter_direction_t iter_dir,
                                     rnn_layer_direction_t lay_dir,
                                     int64_t dir_val, rnn_action_t action) {
        // we first need to copy the initial src_layer and src_iter{,_c} into
        // ws to simplify the logic of the code
        BENCHDNN_PRINT(80,
                "rnn_linear_fwd: call copy_init dir_val = " IFMT "\n", dir_val);
        copy_init_fwd(prb, ws_src_layer, ws_src_iter, ws_src_iter_c, src_layer_,
                src_iter_, src_iter_c_, iter_dir, lay_dir, dir_val);

        // We run the grid of computation
        for (int64_t il = 0; il < prb.n_layer; il++) {
            for (int64_t it = 0; it < prb.n_iter; it++) {
                BENCHDNN_PRINT(80,
                        "==== layer = " IFMT " iter = " IFMT " ===\n", il, it);
                int64_t iter
                        = (iter_dir == left2right) ? it + 1 : prb.n_iter - it;
                int64_t prev_iter
                        = (iter_dir == left2right) ? iter - 1 : iter + 1;
                int64_t lay = il + 1;
                rnn_cell_fwd(prb, &ws_src_layer(lay, dir_val, iter, 0, 0),
                        &ws_src_iter(lay, dir_val, iter, 0, 0),
                        &ws_src_iter_c(lay, dir_val, iter, 0, 0),
                        &ws_gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
                        &ws_ht(lay - 1, dir_val, iter - 1, 0, 0),
                        &weights_layer(lay - 1, dir_val, 0, 0),
                        &weights_iter(lay - 1, dir_val, 0, 0),
                        &weights_peephole(lay - 1, dir_val, 0),
                        &weights_projection(lay - 1, dir_val, 0),
                        &weights_projection_compensation(lay - 1, dir_val, 0),
                        &bias(lay - 1, dir_val, 0),
                        &ws_src_layer(lay - 1, dir_val, iter, 0, 0),
                        &ws_src_iter(lay, dir_val, prev_iter, 0, 0),
                        &ws_src_iter_c(lay, dir_val, prev_iter, 0, 0),
                        cell_scratchpad_);
            }
        }

        // Finally we copy the results to the result buffers
        copy_res_fwd(prb, dst_layer_, dst_iter_, dst_iter_c_, ws_src_layer,
                ws_src_iter, ws_src_iter_c, iter_dir, lay_dir, dir_val, action);
    };

    switch (prb.direction) {
        case dnnl_unidirectional_left2right:
            process_direction(left2right, bottom2top, 0, action_copy);
            break;
        case dnnl_unidirectional_right2left:
            process_direction(right2left, bottom2top, 0, action_copy);
            break;
        case dnnl_bidirectional_sum:
            process_direction(left2right, bottom2top, 0, action_copy);
            process_direction(right2left, bottom2top, 1, action_sum);
            break;
        case dnnl_bidirectional_concat:
            process_direction(left2right, bottom2top, 0, action_copy);
            process_direction(right2left, bottom2top, 1, action_concat);
            break;
        default: assert(!"unknown direction"); break;
    }

    delete[] cell_scratchpad_;
    delete[] bias_with_compensation;
    delete[] weights_projection_compensation_;
}

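// Entry point for the forward reference computation: sets up the workspace
// views and runs the linear (cell-by-cell) reference implementation on the
// benchdnn memory objects reinterpreted as plain float buffers.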
void compute_ref_fwd(const prb_t &prb, dnn_mem_t &src_layer_m,
        dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m,
        dnn_mem_t &weights_src_layer_m, dnn_mem_t &weights_src_iter_m,
        dnn_mem_t &weights_peephole_m, dnn_mem_t &weights_projection_m,
        dnn_mem_t &bias_m, dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m,
        dnn_mem_t &dst_iter_c_m) {
    std::vector<float> ws_fwd_buffer;
    AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht;
    prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c,
            ws_gates, ws_ht);

    rnn_linear_fwd(prb, (float *)src_layer_m, (float *)src_iter_m,
            (float *)src_iter_c_m, (float *)weights_src_layer_m,
            (float *)weights_src_iter_m, (float *)weights_peephole_m,
            (float *)weights_projection_m, (float *)bias_m,
            (float *)dst_layer_m, (float *)dst_iter_m, (float *)dst_iter_c_m,
            ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht);
}

} // namespace rnn