1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18
19 #include <cmath>
20
21 #include "rnn/rnn.hpp"
22 #include "rnn/rnn_aux.hpp"
23
24 #include "rnn/cells.hpp"
25
26 namespace rnn {
27
// Sizes the single forward workspace buffer and binds the AOC views into it.
// Layout (in order): [src_layer / src_iter (aliased)] [src_iter_c?] [gates]
// [ht?] — the optional regions are materialized only for LSTM / projection
// LSTM respectively; otherwise their views alias the following region but are
// never read or written.
void prepare_ws_fwd(const prb_t &prb, std::vector<float> &ws_fwd_buffer,
        AOC<float> &ws_src_layer, AOC<float> &ws_src_iter,
        AOC<float> &ws_src_iter_c, AOC<float> &ws_gates, AOC<float> &ws_ht) {
    bool is_lstm = prb.alg == VANILLA_LSTM;
    bool is_lstmp = prb.is_lstm_projection();

    // `+2` on layer/iter dims leaves border slots for the initial states
    // (copy_init_fwd) and keeps indexing uniform inside the compute grid.
    ws_src_layer = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_src_iter_c = AOC<float>(nullptr, prb.n_layer + 2, prb.n_dir(),
            prb.n_iter + 2, prb.mb, prb.wc);
    ws_gates = AOC<float>(nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb,
            prb.n_gates(), prb.dhc);
    ws_ht = AOC<float>(
            nullptr, prb.n_layer, prb.n_dir(), prb.n_iter, prb.mb, prb.wc);

    // src_layer and src_iter share one region (counted once); iter_c and ht
    // contribute only when the algorithm actually uses them.
    int64_t size = ws_src_layer.nelems() + is_lstm * ws_src_iter_c.nelems()
            + ws_gates.nelems() + is_lstmp * ws_ht.nelems();
    ws_fwd_buffer.resize(size);

    float *ptr = ws_fwd_buffer.data();
    // Intentional aliasing: ws_src_layer and ws_src_iter view the same memory.
    ws_src_layer.set_base_ptr(ptr);
    ws_src_iter.set_base_ptr(ptr);

    ptr += ws_src_iter.nelems();
    ws_src_iter_c.set_base_ptr(ptr);

    // Advance past iter_c only if it was allocated (is_lstm); otherwise
    // ws_gates starts where iter_c's unused view begins.
    ptr += is_lstm * ws_src_iter_c.nelems();
    ws_gates.set_base_ptr(ptr);

    // Same trick for ht: the pointer moves past gates only when ht is real.
    ptr += is_lstmp * ws_gates.nelems();
    ws_ht.set_base_ptr(ptr);
}
62
63 /******************************************************************************/
64 /******************************* Copy Routines ********************************/
65 /******************************************************************************/
prepare_projection_compensation(const prb_t & prb,float * weights_projection_compensation_,const float * weights_projection_)66 void prepare_projection_compensation(const prb_t &prb,
67 float *weights_projection_compensation_,
68 const float *weights_projection_) {
69 AOC<float> weights_projection_compensation(weights_projection_compensation_,
70 prb.n_layer, prb.n_dir(), prb.dic);
71 AOC<const float> weights_projection(
72 weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc, prb.dic);
73 for (int layer = 0; layer < prb.n_layer; ++layer)
74 for (int dir = 0; dir < prb.n_dir(); ++dir)
75 for (int dic = 0; dic < prb.dic; ++dic) {
76 float weights_compensation = 0;
77 for (int dhc = 0; dhc < prb.dhc; ++dhc)
78 weights_compensation
79 += weights_projection(layer, dir, dhc, dic);
80 weights_projection_compensation(layer, dir, dic)
81 = weights_compensation;
82 }
83 }
84
prepare_bias(const prb_t & prb,float * bias_with_compensation_,const float * bias_,const float * weights_layer_,const float * weights_iter_)85 void prepare_bias(const prb_t &prb, float *bias_with_compensation_,
86 const float *bias_, const float *weights_layer_,
87 const float *weights_iter_) {
88 AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
89 prb.slc, prb.n_gates(), prb.dhc);
90 AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
91 prb.sic, prb.n_gates(), prb.dhc);
92
93 AOC<const float> bias(
94 bias_, prb.n_layer, prb.n_dir(), prb.n_gates(), prb.dhc);
95 AOC<float> bias_with_compensation(bias_with_compensation_, prb.n_layer,
96 prb.n_dir(), prb.n_gates(), prb.dhc);
97
98 for (int layer = 0; layer < prb.n_layer; ++layer)
99 for (int dir = 0; dir < prb.n_dir(); ++dir)
100 for (int gate = 0; gate < prb.n_gates(); ++gate)
101 for (int dhc = 0; dhc < prb.dhc; ++dhc) {
102 float weights_compensation = 0;
103 for (int sic = 0; sic < prb.sic; ++sic)
104 weights_compensation
105 += weights_iter(layer, dir, sic, gate, dhc);
106 for (int slc = 0; slc < prb.slc; ++slc)
107 weights_compensation
108 += weights_layer(layer, dir, slc, gate, dhc);
109
110 float scale = prb.data_scale
111 * prb.get_wei_scale(gate * prb.dhc + dhc);
112 bias_with_compensation(layer, dir, gate, dhc)
113 = bias(layer, dir, gate, dhc)
114 - weights_compensation * prb.data_shift / scale;
115 }
116 }
117
copy_init_fwd(const prb_t & prb,const AOC<float> & ws_src_layer,const AOC<float> & ws_src_iter,const AOC<float> & ws_src_iter_c,const float * src_layer_,const float * src_iter_,const float * src_iter_c_,rnn_iter_direction_t iter_dir,rnn_layer_direction_t lay_dir,int64_t dir_val)118 void copy_init_fwd(const prb_t &prb, const AOC<float> &ws_src_layer,
119 const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
120 const float *src_layer_, const float *src_iter_,
121 const float *src_iter_c_, rnn_iter_direction_t iter_dir,
122 rnn_layer_direction_t lay_dir, int64_t dir_val) {
123 AOC<const float> src_layer(src_layer_, prb.n_iter, prb.mb * prb.slc);
124 AOC<const float> src_iter(
125 src_iter_, prb.n_layer, prb.n_dir(), prb.mb * prb.sic);
126 AOC<const float> src_iter_c(
127 src_iter_c_, prb.n_layer, prb.n_dir(), prb.mb * prb.dhc);
128
129 int64_t lay_dest = (lay_dir == bottom2top) ? 0 : prb.n_layer + 1;
130 int64_t it_dest = (iter_dir == left2right) ? 0 : prb.n_iter + 1;
131
132 // Copy src_layer
133 for (int64_t it = 0; it < prb.n_iter; it++) {
134 copy(prb.mb, prb.slc, prb.slc, prb.wc, &src_layer(it, 0),
135 &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0));
136 if (prb.is_int8())
137 data_q10n(prb.mb, prb.slc, prb.wc,
138 &ws_src_layer(lay_dest, dir_val, it + 1, 0, 0),
139 prb.data_scale, prb.data_shift);
140 }
141
142 // Copy src_iter (and src_iter_c)
143 for (int64_t lay = 0; lay < prb.n_layer; lay++) {
144 copy(prb.mb, prb.sic, prb.sic, prb.wc, &src_iter(lay, dir_val, 0),
145 &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0));
146 if (prb.is_int8())
147 data_q10n(prb.mb, prb.sic, prb.wc,
148 &ws_src_iter(lay + 1, dir_val, it_dest, 0, 0),
149 prb.data_scale, prb.data_shift);
150
151 if (prb.alg == VANILLA_LSTM)
152 copy(prb.mb, prb.dhc, prb.dhc, prb.wc, &src_iter_c(lay, dir_val, 0),
153 &ws_src_iter_c(lay + 1, dir_val, it_dest, 0, 0));
154 }
155 }
156
// Copies results out of the workspace into the user dst buffers for one
// direction pass: dst_layer for every iteration (with copy/sum/concat
// `action` for bidirectional runs) and dst_iter{,_c} for every layer,
// dequantizing where the destination data type requires it.
void copy_res_fwd(const prb_t &prb, float *dst_layer_, float *dst_iter_,
        float *dst_iter_c_, const AOC<const float> &ws_src_layer,
        const AOC<const float> &ws_src_iter,
        const AOC<const float> &ws_src_iter_c, rnn_iter_direction_t iter_dir,
        rnn_layer_direction_t lay_dir, int64_t dir_val, rnn_action_t action) {
    AOC<float> dst_iter(dst_iter_, prb.n_layer, prb.n_dir(), prb.mb, prb.dic);
    AOC<float> dst_iter_c(
            dst_iter_c_, prb.n_layer, prb.n_dir(), prb.mb, prb.dhc);
    AOC<float> dst_layer(dst_layer_, prb.n_iter, prb.mb, prb.dlc(PRIMITIVE));
    // Dequantize only when the computation ran in int8 but the destination
    // is a non-int8 (e.g. f32) data type.
    const bool is_layer_deq = (prb.is_u8() && prb.cfg[DST_LAYER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_LAYER].dt != dnnl_s8);
    const bool is_iter_deq = (prb.is_u8() && prb.cfg[DST_ITER].dt != dnnl_u8)
            || (prb.is_s8() && prb.cfg[DST_ITER].dt != dnnl_s8);
    // Copy dst_layer
    for (int64_t it = 0; it < prb.n_iter; it++) {
        for (int64_t nb = 0; nb < prb.mb; nb++) {
            // Source: the last real layer's output for this iteration.
            auto from = &ws_src_layer(prb.n_layer, dir_val, it + 1, nb, 0);
            // For `concat`, the second direction writes after the first
            // direction's dlc(CELL) channels.
            auto to = &dst_layer(
                    it, nb, action == action_concat ? prb.dlc(CELL) : 0);
            copy(1, prb.dlc(CELL), prb.wc, prb.dlc(PRIMITIVE), from, to, action,
                    prb.is_int8());

            if (is_layer_deq) {
                float data_shift = prb.data_shift;
                bool do_deq10n = true;

                if (prb.direction == dnnl_bidirectional_sum) {
                    // In `bidir_sum` case, we need to dequantize data only
                    // after the final summation. Also, since we sum two shifted
                    // tensors, we need to enlarge the shift by 2x.
                    do_deq10n = action == action_sum;
                    data_shift *= 2;
                }

                if (do_deq10n)
                    data_deq10n(1, prb.dlc(CELL), prb.dlc(PRIMITIVE), to,
                            prb.data_scale, data_shift);
            }
        }
    }

    // The final hidden state lives in the last iteration reached by this
    // direction: the right border for left2right, the left border otherwise.
    int64_t it_source = (iter_dir == left2right) ? prb.n_iter : 1;

    // Copy dst_iter (and dst_iter_c)
    for (int64_t lay = 0; lay < prb.n_layer; lay++) {
        if (prb.alg == VANILLA_LSTM) {
            // Cell state is never quantized, so no dequantization here.
            copy(prb.mb, prb.dhc, prb.wc, prb.dhc,
                    &ws_src_iter_c(lay + 1, dir_val, it_source, 0, 0),
                    &dst_iter_c(lay, dir_val, 0, 0));
        }

        copy(prb.mb, prb.dic, prb.wc, prb.dic,
                &ws_src_iter(lay + 1, dir_val, it_source, 0, 0),
                &dst_iter(lay, dir_val, 0, 0));
        if (is_iter_deq)
            data_deq10n(prb.mb, prb.dic, prb.dic, &dst_iter(lay, dir_val, 0, 0),
                    prb.data_scale, prb.data_shift);
    }
}
216
217 /******************************************************************************/
218 /*************************** Computation Routines *****************************/
219 /******************************************************************************/
220
rnn_cell_fwd(const prb_t & prb,float * dst_layer,float * dst_iter,float * dst_iter_c,float * gates,float * ht,const float * weights_layer,const float * weights_iter,const float * weights_peephole,const float * weights_projection,const float * weights_projection_compensation,const float * bias,const float * src_layer,const float * src_iter,const float * src_iter_c,float * cell_scratchpad_)221 void rnn_cell_fwd(const prb_t &prb, float *dst_layer, float *dst_iter,
222 float *dst_iter_c, float *gates, float *ht, const float *weights_layer,
223 const float *weights_iter, const float *weights_peephole,
224 const float *weights_projection,
225 const float *weights_projection_compensation, const float *bias,
226 const float *src_layer, const float *src_iter, const float *src_iter_c,
227 float *cell_scratchpad_) {
228 if (prb.alg != VANILLA_LSTM) assert(dst_layer == dst_iter);
229
230 switch (prb.alg) {
231 case VANILLA_GRU:
232 gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
233 src_layer, src_iter);
234 break;
235 case LBR_GRU:
236 lbr_gru_fwd(prb, dst_layer, gates, weights_layer, weights_iter,
237 bias, src_layer, src_iter, cell_scratchpad_);
238 break;
239 case VANILLA_LSTM:
240 lstm_fwd(prb, dst_layer, dst_iter, dst_iter_c, gates, ht,
241 weights_layer, weights_iter, weights_peephole,
242 weights_projection, weights_projection_compensation, bias,
243 src_layer, src_iter, src_iter_c);
244 break;
245 case VANILLA_RNN:
246 rnn_fwd(prb, dst_layer, gates, weights_layer, weights_iter, bias,
247 src_layer, src_iter);
248 break;
249 default: break;
250 }
251 }
252
rnn_linear_fwd(const prb_t & prb,const float * src_layer_,const float * src_iter_,const float * src_iter_c_,const float * weights_layer_,const float * weights_iter_,const float * weights_peephole_,const float * weights_projection_,const float * bias_,float * dst_layer_,float * dst_iter_,float * dst_iter_c_,const AOC<float> & ws_src_layer,const AOC<float> & ws_src_iter,const AOC<float> & ws_src_iter_c,const AOC<float> & ws_gates,const AOC<float> & ws_ht)253 void rnn_linear_fwd(const prb_t &prb, const float *src_layer_,
254 const float *src_iter_, const float *src_iter_c_,
255 const float *weights_layer_, const float *weights_iter_,
256 const float *weights_peephole_, const float *weights_projection_,
257 const float *bias_, float *dst_layer_, float *dst_iter_,
258 float *dst_iter_c_, const AOC<float> &ws_src_layer,
259 const AOC<float> &ws_src_iter, const AOC<float> &ws_src_iter_c,
260 const AOC<float> &ws_gates, const AOC<float> &ws_ht) {
261 bool is_lbr = prb.alg == LBR_GRU;
262
263 float *bias_with_compensation = nullptr;
264 float *weights_projection_compensation_ = nullptr;
265 if (prb.is_int8()) {
266 bias_with_compensation = new float[prb.n_layer * prb.n_dir()
267 * (prb.n_gates() + is_lbr) * prb.dhc];
268 prepare_bias(prb, bias_with_compensation, bias_, weights_layer_,
269 weights_iter_);
270 bias_ = bias_with_compensation;
271 if (prb.is_lstm_projection()) {
272 weights_projection_compensation_
273 = new float[prb.n_layer * prb.n_dir() * prb.dic];
274 prepare_projection_compensation(
275 prb, weights_projection_compensation_, weights_projection_);
276 }
277 }
278
279 AOC<const float> weights_peephole(
280 weights_peephole_, prb.n_layer, prb.n_dir(), 3 * prb.dhc);
281 AOC<const float> weights_projection(
282 weights_projection_, prb.n_layer, prb.n_dir(), prb.dhc * prb.dic);
283 AOC<const float> weights_projection_compensation(
284 weights_projection_compensation_, prb.n_layer, prb.n_dir(),
285 prb.dic);
286 AOC<const float> bias(bias_, prb.n_layer, prb.n_dir(),
287 (prb.n_gates() + is_lbr) * prb.dhc);
288 AOC<const float> weights_layer(weights_layer_, prb.n_layer, prb.n_dir(),
289 prb.n_gates() * prb.dhc, prb.slc);
290 AOC<const float> weights_iter(weights_iter_, prb.n_layer, prb.n_dir(),
291 prb.n_gates() * prb.dhc, prb.sic);
292
293 int64_t cell_scratchpad_size = is_lbr * prb.mb * prb.n_gates() * prb.dhc;
294 float *cell_scratchpad_ = new float[cell_scratchpad_size];
295 for (int i = 0; i < cell_scratchpad_size; i++) {
296 cell_scratchpad_[i] = NAN;
297 }
298
299 auto process_direction = [&](rnn_iter_direction_t iter_dir,
300 rnn_layer_direction_t lay_dir,
301 int64_t dir_val, rnn_action_t action) {
302 // we first need to copy the initial src_layer and src_iter{,_c} into
303 // ws to simplify the logic of the code
304 BENCHDNN_PRINT(80,
305 "rnn_linear_fwd: call copy_init dir_val = " IFMT "\n", dir_val);
306 copy_init_fwd(prb, ws_src_layer, ws_src_iter, ws_src_iter_c, src_layer_,
307 src_iter_, src_iter_c_, iter_dir, lay_dir, dir_val);
308
309 // We run the grid of computation
310 for (int64_t il = 0; il < prb.n_layer; il++) {
311 for (int64_t it = 0; it < prb.n_iter; it++) {
312 BENCHDNN_PRINT(80,
313 "==== layer = " IFMT " iter = " IFMT " ===\n", il, it);
314 int64_t iter
315 = (iter_dir == left2right) ? it + 1 : prb.n_iter - it;
316 int64_t prev_iter
317 = (iter_dir == left2right) ? iter - 1 : iter + 1;
318 int64_t lay = il + 1;
319 rnn_cell_fwd(prb, &ws_src_layer(lay, dir_val, iter, 0, 0),
320 &ws_src_iter(lay, dir_val, iter, 0, 0),
321 &ws_src_iter_c(lay, dir_val, iter, 0, 0),
322 &ws_gates(lay - 1, dir_val, iter - 1, 0, 0, 0),
323 &ws_ht(lay - 1, dir_val, iter - 1, 0, 0),
324 &weights_layer(lay - 1, dir_val, 0, 0),
325 &weights_iter(lay - 1, dir_val, 0, 0),
326 &weights_peephole(lay - 1, dir_val, 0),
327 &weights_projection(lay - 1, dir_val, 0),
328 &weights_projection_compensation(lay - 1, dir_val, 0),
329 &bias(lay - 1, dir_val, 0),
330 &ws_src_layer(lay - 1, dir_val, iter, 0, 0),
331 &ws_src_iter(lay, dir_val, prev_iter, 0, 0),
332 &ws_src_iter_c(lay, dir_val, prev_iter, 0, 0),
333 cell_scratchpad_);
334 }
335 }
336
337 // Finally we copy the results to the result buffers
338 copy_res_fwd(prb, dst_layer_, dst_iter_, dst_iter_c_, ws_src_layer,
339 ws_src_iter, ws_src_iter_c, iter_dir, lay_dir, dir_val, action);
340 };
341
342 switch (prb.direction) {
343 case dnnl_unidirectional_left2right:
344 process_direction(left2right, bottom2top, 0, action_copy);
345 break;
346 case dnnl_unidirectional_right2left:
347 process_direction(right2left, bottom2top, 0, action_copy);
348 break;
349 case dnnl_bidirectional_sum:
350 process_direction(left2right, bottom2top, 0, action_copy);
351 process_direction(right2left, bottom2top, 1, action_sum);
352 break;
353 case dnnl_bidirectional_concat:
354 process_direction(left2right, bottom2top, 0, action_copy);
355 process_direction(right2left, bottom2top, 1, action_concat);
356 break;
357 default: assert(!"unknown direction"); break;
358 }
359
360 delete[] cell_scratchpad_;
361 delete[] bias_with_compensation;
362 delete[] weights_projection_compensation_;
363 }
364
compute_ref_fwd(const prb_t & prb,dnn_mem_t & src_layer_m,dnn_mem_t & src_iter_m,dnn_mem_t & src_iter_c_m,dnn_mem_t & weights_src_layer_m,dnn_mem_t & weights_src_iter_m,dnn_mem_t & weights_peephole_m,dnn_mem_t & weights_projection_m,dnn_mem_t & bias_m,dnn_mem_t & dst_layer_m,dnn_mem_t & dst_iter_m,dnn_mem_t & dst_iter_c_m)365 void compute_ref_fwd(const prb_t &prb, dnn_mem_t &src_layer_m,
366 dnn_mem_t &src_iter_m, dnn_mem_t &src_iter_c_m,
367 dnn_mem_t &weights_src_layer_m, dnn_mem_t &weights_src_iter_m,
368 dnn_mem_t &weights_peephole_m, dnn_mem_t &weights_projection_m,
369 dnn_mem_t &bias_m, dnn_mem_t &dst_layer_m, dnn_mem_t &dst_iter_m,
370 dnn_mem_t &dst_iter_c_m) {
371 std::vector<float> ws_fwd_buffer;
372 AOC<float> ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht;
373 prepare_ws_fwd(prb, ws_fwd_buffer, ws_src_layer, ws_src_iter, ws_src_iter_c,
374 ws_gates, ws_ht);
375
376 rnn_linear_fwd(prb, (float *)src_layer_m, (float *)src_iter_m,
377 (float *)src_iter_c_m, (float *)weights_src_layer_m,
378 (float *)weights_src_iter_m, (float *)weights_peephole_m,
379 (float *)weights_projection_m, (float *)bias_m,
380 (float *)dst_layer_m, (float *)dst_iter_m, (float *)dst_iter_c_m,
381 ws_src_layer, ws_src_iter, ws_src_iter_c, ws_gates, ws_ht);
382 }
383
384 } // namespace rnn
385