1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
/*
 * Cell execution GRU
 */
20 
21 #include "common/bit_cast.hpp"
22 #include "common/dnnl_thread.hpp"
23 #include "common/math_utils.hpp"
24 
25 #include "cpu/simple_q10n.hpp"
26 
27 #include "cpu/rnn/postgemm_dispatcher.hpp"
28 
29 namespace dnnl {
30 namespace impl {
31 namespace cpu {
32 
33 using namespace dnnl::impl::utils;
34 using namespace dnnl::impl::math;
35 using namespace rnn_utils;
36 #define AOC array_offset_calculator
37 
38 template <typename T1, typename T2, typename T3, typename T4, typename T5,
39         typename src_data_t, typename scratch_data_t>
gru_fwd_part1_postgemm_template(T1 func1,T2 to_src,T3 acc_to_float,T4 src_to_float,T5 reinterpret_as_acc,const float * scales,const rnn_utils::rnn_conf_t & rnn,rnn_utils::cell_position_t cell_position,src_data_t * ws_gates_,scratch_data_t * scratch_gates_,src_data_t * dst_layer_,src_data_t * dst_iter_,const src_data_t * src_iter_,const void * bias_)40 void gru_fwd_part1_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
41         T4 src_to_float, T5 reinterpret_as_acc, const float *scales,
42         const rnn_utils::rnn_conf_t &rnn,
43         rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
44         scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
45         src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_) {
46     const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
47     const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
48     const auto bias_aoc = rnn_utils::make_raw_aoc(
49             bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
50     const auto bias = [&](int gate_id, int dhc_id) {
51         return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
52     };
53 
54     const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
55     const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
56     const auto src_iter_ld = rnn.src_iter_ld(cell_position);
57 
58     const ws_states_layer_aoc<src_data_t> dst_layer(
59             rnn, dst_layer_, dst_layer_ld);
60     const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
61     const ws_states_iter_aoc<const src_data_t> src_iter(
62             rnn, src_iter_, src_iter_ld);
63 
64     parallel_nd(rnn.mb, [&](dim_t i) {
65         PRAGMA_OMP_SIMD()
66         for (int j = 0; j < rnn.dhc; j++) {
67             const auto G0 // default func1 is sigmoid
68                     = func1(scales,
69                             acc_to_float(scratch_gates(i, 0, j), 0, j)
70                                     + bias(0, j));
71             const auto G1 // default func1 is sigmoid
72                     = func1(scales + 1,
73                             acc_to_float(scratch_gates(i, 1, j), 1, j)
74                                     + bias(1, j));
75             /* TODO: Can be optimized for fwd_training by using ws_gates instead of scratch_gates in p2 */
76             scratch_gates(i, 0, j) = reinterpret_as_acc(G0);
77             const auto t = to_src(src_to_float(src_iter(i, j)) * G1);
78             if (dst_layer_) dst_layer(i, j) = t;
79             if (dst_iter_) dst_iter(i, j) = t;
80 
81             if (rnn.is_training) {
82                 ws_gates(i, 0, j) = to_src(G0);
83                 ws_gates(i, 1, j) = to_src(G1);
84             }
85         }
86     });
87 }
88 
89 template <typename T1, typename T2, typename T3, typename T4, typename T5,
90         typename src_data_t, typename scratch_data_t>
gru_fwd_part2_postgemm_template(T1 func1,T2 to_src,T3 acc_to_float,T4 src_to_float,T5 reinterpret_as_float,const float * scales,const rnn_utils::rnn_conf_t & rnn,rnn_utils::cell_position_t cell_position,src_data_t * ws_gates_,scratch_data_t * scratch_gates_,src_data_t * dst_layer_,src_data_t * dst_iter_,const src_data_t * src_iter_,const void * bias_)91 void gru_fwd_part2_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
92         T4 src_to_float, T5 reinterpret_as_float, const float *scales,
93         const rnn_utils::rnn_conf_t &rnn,
94         rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
95         scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
96         src_data_t *dst_iter_, const src_data_t *src_iter_, const void *bias_) {
97     const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
98     const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
99     const auto bias_aoc = rnn_utils::make_raw_aoc(
100             bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
101     const auto bias = [&](int gate_id, int dhc_id) {
102         return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
103     };
104 
105     const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
106     const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
107     const auto src_iter_ld = rnn.src_iter_ld(cell_position);
108     const ws_states_layer_aoc<src_data_t> dst_layer(
109             rnn, dst_layer_, dst_layer_ld);
110     const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
111     const ws_states_iter_aoc<const src_data_t> src_iter(
112             rnn, src_iter_, src_iter_ld);
113 
114     parallel_nd(rnn.mb, [&](dim_t i) {
115         PRAGMA_OMP_SIMD()
116         for (int j = 0; j < rnn.dhc; j++) {
117             const auto G0 = reinterpret_as_float(scratch_gates(i, 0, j));
118             const auto G2 // default func1 is tanh
119                     = func1(scales + 2,
120                             acc_to_float(scratch_gates(i, 2, j), 2, j)
121                                     + bias(2, j));
122 
123             const auto tmp = to_src(
124                     src_to_float(src_iter(i, j)) * G0 + (1.0f - G0) * G2);
125             if (dst_layer_ != nullptr) dst_layer(i, j) = tmp;
126             if (dst_iter_ != nullptr) dst_iter(i, j) = tmp;
127 
128             if (rnn.is_training) { ws_gates(i, 2, j) = to_src(G2); }
129         }
130     });
131 }
132 
133 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part1_postgemm)134 rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part1_postgemm) {
135     const float *scales = pd_->attr()->rnn_tparams_.scales_;
136     const auto linear_f
137             = [](const float *scale, float a) { return *scale * a; };
138     const auto logistic_f = [](const float *scale, float a) {
139         return logistic_fwd<float>(a);
140     };
141 
142     const auto deq_id = [](float f, int i, int j) { return f; };
143     const auto id = [](float f) { return f; };
144 
145     if (!pd_->attr()->rnn_tparams_.test_mode_)
146         gru_fwd_part1_postgemm_template(logistic_f, id, deq_id, id, id, scales,
147                 rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
148                 dst_iter_, src_iter_, bias_);
149     else
150         gru_fwd_part1_postgemm_template(linear_f, id, deq_id, id, id, scales,
151                 rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
152                 dst_iter_, src_iter_, bias_);
153 }
154 
155 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part2_postgemm)156 rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part2_postgemm) {
157     const float *scales = pd_->attr()->rnn_tparams_.scales_;
158     const auto linear_f
159             = [](const float *scale, float a) { return *scale * a; };
160     const auto tanh_f
161             = [](const float *scale, float a) { return tanh_fwd<float>(a); };
162 
163     const auto deq_id = [](float f, int i, int j) { return f; };
164     const auto id = [](float f) { return f; };
165 
166     if (!pd_->attr()->rnn_tparams_.test_mode_)
167         gru_fwd_part2_postgemm_template(tanh_f, id, deq_id, id, id, scales, rnn,
168                 cell_position, ws_gates_, scratch_gates_, dst_layer_, dst_iter_,
169                 src_iter_, bias_);
170     else
171         gru_fwd_part2_postgemm_template(linear_f, id, deq_id, id, id, scales,
172                 rnn, cell_position, ws_gates_, scratch_gates_, dst_layer_,
173                 dst_iter_, src_iter_, bias_);
174 }
175 
176 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part1_postgemm)177 rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part1_postgemm) {
178     const float *scales = pd_->attr()->rnn_tparams_.scales_;
179     const auto linear_f
180             = [](const float *scale, float a) { return *scale * a; };
181     const auto logistic_f = [](const float *scale, float a) {
182         return logistic_fwd<float>(a);
183     };
184 
185     const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
186     const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
187     const auto deq_id = [](float f, int i, int j) { return f; };
188     const auto id = [](float f) { return f; };
189 
190     if (!pd_->attr()->rnn_tparams_.test_mode_)
191         gru_fwd_part1_postgemm_template(logistic_f, dn_cvt_f32_bf16, deq_id,
192                 up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
193                 scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
194     else
195         gru_fwd_part1_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
196                 up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
197                 scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
198 }
199 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part2_postgemm)200 rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part2_postgemm) {
201     const float *scales = pd_->attr()->rnn_tparams_.scales_;
202     const auto linear_f
203             = [](const float *scale, float a) { return *scale * a; };
204     const auto tanh_f
205             = [](const float *scale, float a) { return tanh_fwd<float>(a); };
206 
207     const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
208     const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
209     const auto deq_id = [](float f, int i, int j) { return f; };
210     const auto id = [](float f) { return f; };
211 
212     if (!pd_->attr()->rnn_tparams_.test_mode_)
213         gru_fwd_part2_postgemm_template(tanh_f, dn_cvt_f32_bf16, deq_id,
214                 up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
215                 scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
216     else
217         gru_fwd_part2_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
218                 up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
219                 scratch_gates_, dst_layer_, dst_iter_, src_iter_, bias_);
220 }
221 
222 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part1_postgemm)223 rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part1_postgemm) {
224     const float *scales = pd_->attr()->rnn_tparams_.scales_;
225     const auto linear_f
226             = [](const float *scale, float a) { return *scale * a; };
227     const auto logistic_f = [](const float *scale, float a) {
228         return logistic_fwd<float>(a);
229     };
230 
231     const float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
232     const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
233     const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
234 
235     const auto quantize_f32_u8 = [&](float f) {
236         float qf = f * data_scale + data_shift;
237         qf = nstl::min(qf, 255.0f);
238         qf = nstl::max(qf, 0.0f);
239         return (dst_layer_t)mxcsr_cvt(qf);
240     };
241 
242     const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
243         const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
244                 ? weights_scales[0]
245                 : weights_scales[gate * rnn.dhc + j];
246         return saturate<float>(s) * (1.f / (wscale * data_scale));
247     };
248 
249     const auto dequantize_u8_f32 = [&](src_iter_t s) {
250         return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
251     };
252 
253     const auto reinterpret_f32_s32
254             = [](float a) { return bit_cast<gemm_acc_t>(a); };
255 
256     if (!pd_->attr()->rnn_tparams_.test_mode_)
257         gru_fwd_part1_postgemm_template(logistic_f, quantize_f32_u8,
258                 dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
259                 scales, rnn, cell_position, ws_gates_, scratch_gates_,
260                 dst_layer_, dst_iter_, src_iter_, bias_);
261     else
262         gru_fwd_part1_postgemm_template(linear_f, quantize_f32_u8,
263                 dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
264                 scales, rnn, cell_position, ws_gates_, scratch_gates_,
265                 dst_layer_, dst_iter_, src_iter_, bias_);
266 }
267 
268 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part2_postgemm)269 rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part2_postgemm) {
270     const float *scales = pd_->attr()->rnn_tparams_.scales_;
271     const auto linear_f
272             = [](const float *scale, float a) { return *scale * a; };
273     const auto tanh_f
274             = [](const float *scale, float a) { return tanh_fwd<float>(a); };
275 
276     const float *weights_scales = pd_->attr()->rnn_weights_qparams_.scales_;
277     const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
278     const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;
279 
280     const auto quantize_f32_u8 = [&](float f) {
281         float qf = f * data_scale + data_shift;
282         qf = nstl::min(qf, 255.0f);
283         qf = nstl::max(qf, 0.0f);
284         return (dst_layer_t)mxcsr_cvt(qf);
285     };
286 
287     const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
288         const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
289                 ? weights_scales[0]
290                 : weights_scales[gate * rnn.dhc + j];
291         return saturate<float>(s) * (1.f / (wscale * data_scale));
292     };
293 
294     const auto dequantize_u8_f32 = [&](src_iter_t s) {
295         return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
296     };
297 
298     const auto reinterpret_s32_f32
299             = [](gemm_acc_t a) { return bit_cast<float>(a); };
300 
301     if (!pd_->attr()->rnn_tparams_.test_mode_)
302         gru_fwd_part2_postgemm_template(tanh_f, quantize_f32_u8,
303                 dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
304                 scales, rnn, cell_position, ws_gates_, scratch_gates_,
305                 dst_layer_, dst_iter_, src_iter_, bias_);
306     else
307         gru_fwd_part2_postgemm_template(linear_f, quantize_f32_u8,
308                 dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
309                 scales, rnn, cell_position, ws_gates_, scratch_gates_,
310                 dst_layer_, dst_iter_, src_iter_, bias_);
311 }
312 
313 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part1_postgemm)314 rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part1_postgemm) {
315     assert(!"GRU signed int8 is not supported");
316 }
317 
318 template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part2_postgemm)319 rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part2_postgemm) {
320     assert(!"GRU signed int8 is not supported");
321 }
322 
323 template <typename T, typename src_data_t, typename acc_data_t,
324         typename scratch_data_t>
gru_bwd_part1_postgemm_template(T to_src,const rnn_utils::rnn_conf_t & rnn,cell_position_t cell_position,src_data_t * ws_gates_,scratch_data_t * scratch_gates_,src_data_t * dst_layer_,const src_data_t * src_iter_,acc_data_t * diff_src_iter_,acc_data_t * diff_dst_iter_,acc_data_t * diff_dst_layer_)325 void gru_bwd_part1_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
326         cell_position_t cell_position, src_data_t *ws_gates_,
327         scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
328         const src_data_t *src_iter_, acc_data_t *diff_src_iter_,
329         acc_data_t *diff_dst_iter_, acc_data_t *diff_dst_layer_) {
330     const auto src_iter_ld = rnn.src_iter_ld(cell_position);
331 
332     const ws_states_iter_aoc<const src_data_t> src_iter(
333             rnn, src_iter_, src_iter_ld);
334     const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
335     const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
336     const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
337             rnn, diff_src_iter_);
338     const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
339             rnn, diff_dst_iter_);
340     const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
341             rnn, diff_dst_layer_);
342 
343     // dG2^ = dh * (1 - G0) * (1 - G2^2)
344     // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
345     // dht-1 (part) = dh * G0
346     parallel_nd(rnn.mb, [&](dim_t i) {
347         PRAGMA_OMP_SIMD()
348         for (int j = 0; j < rnn.dhc; j++) {
349             const float h = src_iter(i, j);
350             const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j);
351             const float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
352                     * one_m_square(ws_gates(i, 2, j));
353             const float dG0 = (h - ws_gates(i, 2, j)) * dHt
354                     * x_m_square(ws_gates(i, 0, j));
355 
356             diff_src_iter(i, j) = dHt * ws_gates(i, 0, j);
357             scratch_gates(i, 0, j) = to_src(dG0);
358             scratch_gates(i, 2, j) = to_src(dG2);
359         }
360     });
361 }
362 
363 template <typename T, typename src_data_t, typename acc_data_t,
364         typename scratch_data_t>
gru_bwd_part2_postgemm_template(T to_src,const rnn_utils::rnn_conf_t & rnn,cell_position_t cell_position,src_data_t * ws_gates_,scratch_data_t * scratch_gates_,src_data_t * dst_layer_,const src_data_t * src_iter_,acc_data_t * diff_src_layer_,acc_data_t * diff_src_iter_,acc_data_t * diff_dst_iter_,acc_data_t * diff_dst_layer_,scratch_data_t * scratch_cell_)365 void gru_bwd_part2_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
366         cell_position_t cell_position, src_data_t *ws_gates_,
367         scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
368         const src_data_t *src_iter_, acc_data_t *diff_src_layer_,
369         acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
370         acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_) {
371     const auto src_iter_ld = rnn.src_iter_ld(cell_position);
372     // auto dst_ld = rnn.dst_ld(cell_position);
373     // ws_states_layer_aoc<src_data_t> dst_layer(rnn, dst_layer_, dst_ld);
374     const ws_states_iter_aoc<const src_data_t> src_iter(
375             rnn, src_iter_, src_iter_ld);
376     const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
377     const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
378     const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
379             rnn, diff_dst_layer_);
380     const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
381             rnn, diff_dst_iter_);
382 
383     const ws_diff_states_layer_aoc<acc_data_t> dhG1(rnn, diff_src_layer_);
384     const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
385             rnn, diff_src_iter_);
386     const AOC<scratch_data_t, 2> hG1(
387             scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);
388 
389     // dG1^ = d(hG1) * h * G1 * (1 - G1)
390     // dht-1 (part) += d(hG1) * G1
391     // h * G1 (required for dWh)
392     parallel_nd(rnn.mb, [&](dim_t i) {
393         PRAGMA_OMP_SIMD()
394         for (int j = 0; j < rnn.dhc; j++) {
395             const float h = src_iter(i, j);
396             const float G1 = ws_gates(i, 1, j);
397             diff_src_iter(i, j) += dhG1(i, j) * G1;
398             scratch_gates(i, 1, j) = to_src(dhG1(i, j) * h * x_m_square(G1));
399             hG1(i, j) = to_src(G1 * h);
400         }
401     });
402 }
403 
404 template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part1_postgemm)405 rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part1_postgemm) {
406     const auto to_src = [](float a) { return a; };
407 
408     gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
409             scratch_gates_, dst_layer_, src_iter_, diff_src_iter_,
410             diff_dst_iter_, diff_dst_layer_);
411 }
412 
413 template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part2_postgemm)414 rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part2_postgemm) {
415     const auto to_src = [](float a) { return a; };
416 
417     gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
418             scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
419             diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
420 }
421 
422 template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part1_postgemm)423 rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part1_postgemm) {
424     const auto to_src = [](float a) { return bfloat16_t(a); };
425 
426     gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
427             scratch_gates_, dst_layer_, src_iter_, diff_src_iter_,
428             diff_dst_iter_, diff_dst_layer_);
429 }
430 
431 template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part2_postgemm)432 rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part2_postgemm) {
433     const auto to_src = [](float a) { return bfloat16_t(a); };
434 
435     gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
436             scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
437             diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
438 }
439 
440 #undef AOC
441 } // namespace cpu
442 } // namespace impl
443 } // namespace dnnl
444