/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution LSTM projection
 */

#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;

namespace {
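// Copies the projected state from dst_layer_ to dst_iter_ when the user
// requests dst_iter. In the fused brgemm path only the current m_block rows
// are copied; otherwise the whole minibatch is copied in parallel.
// block_step is the per-row copy size in bytes.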
template <typename dst_layer_t, typename dst_iter_t>
void proj_dst_copy(const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, dst_iter_t *dst_iter_,
        const dst_layer_t *dst_layer_, int block_step) {
    assert(rnn.dic == rnn.dlc);
    static_assert(sizeof(dst_layer_t) == sizeof(dst_iter_t),
            "memcpy requires the same data type size for src and dst");
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);

    // If dst_iter is not nullptr, we need to copy the state to dst_iter
    if (dst_iter_ != nullptr) {
        if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
            for (int i = 0; i < rnn.m_block; i++)
                std::memcpy(dst_iter_ + i * dst_iter_ld,
                        dst_layer_ + i * dst_layer_ld, block_step);
        } else {
            parallel_nd(rnn.mb, [&](dim_t i) {
                std::memcpy(dst_iter_ + i * dst_iter_ld,
                        dst_layer_ + i * dst_layer_ld, block_step);
            });
        }
    }
}
} // namespace

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::lstm_projection_postgemm) {
    // nothing to do for f32, except copy to dst_iter if needed
    proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::lstm_projection_postgemm) {
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);

    // Currently, scratch_gates_ contains the output of the projection
    const int n_elem = block_step / (int)sizeof(dst_layer_t);

    const int m_block
            = (rnn.is_brgemm && !rnn.unfused_post_gemm) ? rnn.m_block : rnn.mb;

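    // Down-convert the f32 projection result row by row from scratch_gates_
    // to bf16 directly in dst_layer_.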
    for (int i = 0; i < m_block; i++)
        cvt_float_to_bfloat16((bfloat16_t *)dst_layer_ + i * dst_layer_ld,
                (float *)scratch_gates_ + i * rnn.scratch_gates_ld, n_elem);

    // we copy to dst_iter if necessary
    proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::lstm_projection_postgemm) {
    // Here, we use
    // - scratch_gates to pass the s32 output of the projection
    // - src_iter_c to pass the projection compensation

    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);
    const auto w_proj_comp = static_cast<const float *>(src_iter_c_);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

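    // Quantization f32 -> u8: apply the data scale and shift, clamp to the
    // [0, 255] range, then round to the destination type via qz_a1b0.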
    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        qf = nstl::min(qf, 255.0f);
        qf = nstl::max(qf, 0.0f);
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

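    // Dequantization s32 -> f32: subtract the precomputed shift compensation
    // (w_proj_comp[j] * data_shift cancels the bias that the shifted u8
    // inputs add to the accumulator) and divide by the combined weight and
    // data scales. Weight scales are per output channel unless the mask is 0.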
    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) {
        const float wscale
                = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[j];
        const float wcomp = w_proj_comp[j] * data_shift;

        return (saturate<float>(s) - wcomp) / (wscale * data_scale);
    };

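    // Per-row kernel: dequantize each s32 element of the projection output
    // and requantize it as u8 into dst_layer_.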
    auto postgemm_call = [&](int i) {
        const int n_elem = block_step / (int)sizeof(dst_layer_t);
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            const int scratch_off = i * rnn.scratch_gates_ld + j;
            const int dst_off = i * dst_layer_ld + j;
            const float tmp
                    = dequantize_s32_f32(scratch_gates_[scratch_off], j);
            dst_layer_[dst_off] = quantize_f32_u8(tmp);
        }
    };
    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
    proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::lstm_projection_postgemm) {
    // Here, we use
    // - scratch_gates to pass the s32 output of the projection
    // - no need to pass the projection compensation for s8s8 amx
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position, true);

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

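    // Quantization f32 -> s8: apply the data scale and shift; qz_a1b0
    // performs the rounding and saturation to the destination type.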
    const auto quantize_f32_s8 = [&](float f) {
        const float qf = f * data_scale + data_shift;
        return qz_a1b0<float, dst_layer_t>()(qf);
    };

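    // Dequantization s32 -> f32: divide by the combined weight and data
    // scales. No shift compensation is needed here (see the s8s8 amx note
    // above).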
    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int j) {
        const float wscale
                = pd_->attr()->rnn_weights_projection_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[j];

        return saturate<float>(s) / (wscale * data_scale);
    };

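    // Per-row kernel: dequantize each s32 element of the projection output
    // and requantize it as s8 into dst_layer_.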
    const auto postgemm_call = [&](dim_t i) {
        const int n_elem = block_step / (int)sizeof(dst_layer_t);
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            const int scratch_off = i * rnn.scratch_gates_ld + j;
            const int dst_off = i * dst_layer_ld + j;
            const float tmp
                    = dequantize_s32_f32(scratch_gates_[scratch_off], j);
            dst_layer_[dst_off] = quantize_f32_s8(tmp);
        }
    };
    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
    proj_dst_copy(rnn, cell_position, dst_iter_, dst_layer_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::lstm_projection_postgemm) {
    assert(!"unsupported");
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::lstm_projection_postgemm) {
    assert(!"unsupported");
}

} // namespace cpu
} // namespace impl
} // namespace dnnl