/*******************************************************************************
 * Copyright 2018-2021 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

/*
 * Cell execution GRU
 */
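
/*
 * The forward GRU postgemm is split in two parts around the two GEMMs of the
 * cell:
 *  - part1 computes the update gate G0 and the reset gate G1 from the first
 *    GEMM output, stashes G0 in scratch_gates, and writes the reset-scaled
 *    state h_{t-1} * G1 to dst_layer/dst_iter so the second GEMM can consume
 *    it.
 *  - part2 computes the candidate gate G2 from the second GEMM output and the
 *    final state h_t = G0 * h_{t-1} + (1 - G0) * G2.
 */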

#include "common/bit_cast.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;
#define AOC array_offset_calculator

template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename src_data_t, typename scratch_data_t>
void gru_fwd_part1_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
        T4 src_to_float, T5 reinterpret_as_acc, const float *scales,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };

    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);

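    // G0 (update gate) and G1 (reset gate) are computed from the first GEMM
    // output plus bias; with the default func1 this is
    //   G0 = sigmoid(x0 + b0), G1 = sigmoid(x1 + b1).
    // The reset-scaled state h_{t-1} * G1 is written to dst so the second
    // GEMM can consume it, and G0 is kept in scratch_gates for part2.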
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const auto G0 // default func1 is sigmoid
                    = func1(scales,
                            acc_to_float(scratch_gates(i, 0, j), 0, j)
                                    + bias(0, j));
            const auto G1 // default func1 is sigmoid
                    = func1(scales + 1,
                            acc_to_float(scratch_gates(i, 1, j), 1, j)
                                    + bias(1, j));
            /* TODO: Can be optimized for fwd_training by using ws_gates
               instead of scratch_gates in p2 */
            scratch_gates(i, 0, j) = reinterpret_as_acc(G0);
            const auto t = to_src(src_to_float(src_iter(i, j)) * G1);
            if (dst_layer_) dst_layer(i, j) = t;
            if (dst_iter_) dst_iter(i, j) = t;

            if (rnn.is_training) {
                ws_gates(i, 0, j) = to_src(G0);
                ws_gates(i, 1, j) = to_src(G1);
            }
        }
    });
}

template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename src_data_t, typename scratch_data_t>
void gru_fwd_part2_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
        T4 src_to_float, T5 reinterpret_as_float, const float *scales,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };

    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);

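    // G2 (candidate gate) is computed from the second GEMM output plus bias;
    // with the default func1 this is G2 = tanh(x2 + b2). The final state is
    //   h_t = h_{t-1} * G0 + (1 - G0) * G2,
    // where G0 is recovered from scratch_gates as stored by part1.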
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const auto G0 = reinterpret_as_float(scratch_gates(i, 0, j));
            const auto G2 // default func1 is tanh
                    = func1(scales + 2,
                            acc_to_float(scratch_gates(i, 2, j), 2, j)
                                    + bias(2, j));

            const auto tmp = to_src(
                    src_to_float(src_iter(i, j)) * G0 + (1.0f - G0) * G2);
            if (dst_layer_ != nullptr) dst_layer(i, j) = tmp;
            if (dst_iter_ != nullptr) dst_iter(i, j) = tmp;

            if (rnn.is_training) { ws_gates(i, 2, j) = to_src(G2); }
        }
    });
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

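    // When rnn_tparams test_mode_ is set, the gate non-linearity is swapped
    // for the scaled linear function above (scales come from rnn_tparams_);
    // the same dispatch pattern is repeated in the specializations below.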
    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
                dst_iter_, src_iter_, bias_);
    else
        gru_fwd_part1_postgemm_template(linear_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
                dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, id, deq_id, id, id, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, dst_layer_, dst_iter_,
                src_iter_, bias_);
    else
        gru_fwd_part2_postgemm_template(linear_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
                dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
    const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
    else
        gru_fwd_part1_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
    const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
    else
        gru_fwd_part2_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

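    // u8 states are quantized as q = clamp(f * data_scale + data_shift, 0, 255)
    // and the s32 GEMM accumulators are dequantized with
    // 1 / (weights_scale * data_scale), where weights_scale is either common
    // (mask == 0) or taken per gate and output channel.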
    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        qf = nstl::min(qf, 255.0f);
        qf = nstl::max(qf, 0.0f);
        return (dst_layer_t)mxcsr_cvt(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales[0]
                : weights_scales[gate * rnn.dhc + j];
        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto dequantize_u8_f32 = [&](src_iter_t s) {
        return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
    };

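    // G0 stays in f32: its bit pattern is stored in the s32 scratch_gates so
    // that part2 can read it back unchanged.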
    const auto reinterpret_f32_s32
            = [](float a) { return bit_cast<gemm_acc_t>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, src_iter_, bias_);
    else
        gru_fwd_part1_postgemm_template(linear_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        qf = nstl::min(qf, 255.0f);
        qf = nstl::max(qf, 0.0f);
        return (dst_layer_t)mxcsr_cvt(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales[0]
                : weights_scales[gate * rnn.dhc + j];
        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto dequantize_u8_f32 = [&](src_iter_t s) {
        return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
    };

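    // The f32 value of G0 that part1 stored into the s32 scratch_gates is
    // recovered here by bit-casting back to float.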
    const auto reinterpret_s32_f32
            = [](gemm_acc_t a) { return bit_cast<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, src_iter_, bias_);
    else
        gru_fwd_part2_postgemm_template(linear_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                dst_layer_, dst_iter_, src_iter_, bias_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part1_postgemm) {
    assert(!"GRU signed int8 is not supported");
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part2_postgemm) {
    assert(!"GRU signed int8 is not supported");
}

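/*
 * Backward pass. part1 computes dG0 and dG2 and the G0 * dh contribution to
 * diff_src_iter; after the GEMM that produces d(h * G1) (kept in
 * diff_src_layer_ here), part2 computes dG1, accumulates the remaining
 * diff_src_iter term, and stores h * G1 in scratch_cell for the weight
 * gradient GEMM.
 */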
template <typename T, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_bwd_part1_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        const src_data_t *src_iter_, acc_data_t *diff_src_iter_,
        acc_data_t *diff_dst_iter_, acc_data_t *diff_dst_layer_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);

    // dG2^ = dh * (1 - G0) * (1 - G2^2)
    // dG0^ = dh * (ht-1 - G2) * G0 * (1 - G0)
    // dht-1 (part) = dh * G0
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j);
            const float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
                    * one_m_square(ws_gates(i, 2, j));
            const float dG0 = (h - ws_gates(i, 2, j)) * dHt
                    * x_m_square(ws_gates(i, 0, j));

            diff_src_iter(i, j) = dHt * ws_gates(i, 0, j);
            scratch_gates(i, 0, j) = to_src(dG0);
            scratch_gates(i, 2, j) = to_src(dG2);
        }
    });
}

template <typename T, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_bwd_part2_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        const src_data_t *src_iter_, acc_data_t *diff_src_layer_,
        acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    // auto dst_ld = rnn.dst_ld(cell_position);
    // ws_states_layer_aoc<src_data_t> dst_layer(rnn, dst_layer_, dst_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);

    const ws_diff_states_layer_aoc<acc_data_t> dhG1(rnn, diff_src_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const AOC<scratch_data_t, 2> hG1(
            scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);

    // dG1^ = d(hG1) * h * G1 * (1 - G1)
    // dht-1 (part) += d(hG1) * G1
    // h * G1 (required for dWh)
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            const float G1 = ws_gates(i, 1, j);
            diff_src_iter(i, j) += dhG1(i, j) * G1;
            scratch_gates(i, 1, j) = to_src(dhG1(i, j) * h * x_m_square(G1));
            hG1(i, j) = to_src(G1 * h);
        }
    });
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part1_postgemm) {
    const auto to_src = [](float a) { return a; };

    gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_iter_,
            diff_dst_iter_, diff_dst_layer_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part2_postgemm) {
    const auto to_src = [](float a) { return a; };

    gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
            diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part1_postgemm) {
    const auto to_src = [](float a) { return bfloat16_t(a); };

    gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_iter_,
            diff_dst_iter_, diff_dst_layer_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part2_postgemm) {
    const auto to_src = [](float a) { return bfloat16_t(a); };

    gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
            diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
}

#undef AOC
} // namespace cpu
} // namespace impl
} // namespace dnnl