1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP
18 #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP
19 
20 #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp"
21 
22 namespace dnnl {
23 namespace impl {
24 namespace cpu {
25 namespace x64 {
26 
template <cpu_isa_t isa, impl::data_type_t src_data_t,
        impl::data_type_t scratch_data_t>
struct jit_uni_gru_cell_postgemm_part2_bwd : public jit_uni_rnn_postgemm {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_bwd)

    jit_uni_gru_cell_postgemm_part2_bwd(
            const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd)
        : jit_uni_rnn_postgemm(rnn, pd) {}

    ~jit_uni_gru_cell_postgemm_part2_bwd() {}

    // Initializes the base postgemm state with the compile-time source data
    // type (the runtime `sdt` argument is intentionally ignored here) and
    // then JIT-compiles the kernel emitted by generate().
    status_t init(data_type_t sdt) override {
        jit_uni_rnn_postgemm::init(src_data_t);
        return create_kernel();
    }

protected:
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    // full vector register size in bytes for the chosen ISA
    size_t vlen = cpu_isa_traits<isa>::vlen;
    // bytes of scratch-typed data consumed per full-vector iteration
    // (scratch elements may be narrower than f32, e.g. bf16)
    size_t vlen_scratch
            = vlen / (sizeof(float) / types::data_type_size(scratch_data_t));
    size_t hstate_dt_size = sizeof(float); // hidden state is kept in f32
    size_t gate_dt_size = types::data_type_size(scratch_data_t);
    size_t scratch_dt_size = types::data_type_size(scratch_data_t);

    // Emits the second half of the GRU backward postgemm. Per element:
    //   dG1              = dhG1 * h * (G1 - G1^2)   -> scratch gate 1
    //   scratch_cell     = h * G1                   -> hG1 for the next GEMM
    //   diff_states_t_l += dhG1 * G1
    // where G1 is the update-gate value saved in the workspace during the
    // forward pass and h is the previous-timestep hidden state.
    // The kernel runs a vectorized main loop over rnn_.dhc elements followed
    // by a scalar tail loop.
    void generate() override {
        using namespace Xbyak;

        // Labels declaration
        // NOTE: the *_inc_regs labels are declared but never bound/used.
        Label vector_loop_start_label, vector_loop_inc_regs,
                vector_loop_end_label;
        Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label;

        // Register map
        Reg64 loop_cnt(rbx); // loop counter, in bytes of scratch data

        // We skip vmm0 as it can be used by the injector for masks on sse4.1
        enum {
            dG1_idx = 1,
            dhG1_idx = 2,
            hG1_idx = 3,
            G1_idx = 4,
            dH_idx = 5,
            tmp1_idx = 6,
            h_idx = 7
        };

        // We start code generations here
        preamble();

        // extract addresses passed as parameter
        auto addr_ws_gates_reg = abi_param1;
        auto addr_scratch_gates_reg = abi_param2;
        // auto addr_diff_states_t_lp1_reg = abi_param3; // not needed
        // auto addr_diff_states_tp1_l_reg = abi_param4; // not needed
#ifdef _WIN32
        // Win64 ABI: only four register parameters, so everything from the
        // fifth argument on is loaded from the caller's stack area.
        auto addr_diff_states_t_l_reg = r10;
        auto addr_states_tm1_l_reg = r11;
        auto addr_scratch_cell_reg = r12;
        // auto addr_ws_grid_reg = rsi; // not needed
        auto addr_dhG1_reg = rsi;
        auto base_args = get_stack_params_address();
        mov(addr_diff_states_t_l_reg, ptr[base_args]);
        mov(addr_states_tm1_l_reg, ptr[base_args + 8]);
        mov(addr_scratch_cell_reg, ptr[base_args + 16]);
        // mov(addr_ws_grid_reg, ptr[base_args + 24]);
        mov(addr_dhG1_reg, ptr[base_args + 32]);
#else
        // System V ABI: six register parameters; the remaining arguments
        // come from the stack.
        auto addr_diff_states_t_l_reg = abi_param5;
        auto addr_states_tm1_l_reg = abi_param6;
        auto addr_scratch_cell_reg = r10;
        // auto addr_ws_grid_reg = r11; // not needed
        auto addr_dhG1_reg = r11;
        auto base_args = get_stack_params_address();
        mov(addr_scratch_cell_reg, ptr[base_args]);
        // mov(addr_ws_grid_reg, ptr[base_args + 8]);
        mov(addr_dhG1_reg, ptr[base_args + 16]);
#endif

        // helper lambda to address the gates and biases
        // (gate i starts at offset i * dhc within its buffer)
        auto sg_addr = [&](int i) {
            return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size];
        };
        auto wg_addr = [&](int i) {
            return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size];
        };

        // initialize registers with addresses and constants
        init_regs(vlen);

        // loop_cnt counts remaining bytes of scratch data; skip the vector
        // loop entirely if fewer than one full vector remains
        mov(loop_cnt, rnn_.dhc * scratch_dt_size);
        cmp(loop_cnt, vlen_scratch);
        jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);

        L(vector_loop_start_label);
        {
            Vmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx),
                    dH(dH_idx), tmp1(tmp1_idx), h(h_idx);

            // load gate G1 (from the forward workspace) and h_{t-1},
            // upconverting to f32 if needed
            to_float<src_data_t>(G1, wg_addr(1), vlen);
            to_float<src_data_t>(h, ptr[addr_states_tm1_l_reg], vlen);

            // compute dG1
            uni_vmovups(dG1, G1);
            uni_vmovups(tmp1, G1);
            uni_vfnmadd231ps(dG1, tmp1, tmp1); // dG1 = G1 - G1^2
            uni_vmulps(dG1, dG1, h);
            uni_vmovups(dhG1, ptr[addr_dhG1_reg]);
            uni_vmulps(dG1, dG1, dhG1); // dG1 = dhG1 * h * (G1 - G1^2)

            // compute hG1 = h * G1 (input to the part-2 GEMM)
            uni_vmovups(hG1, G1);
            uni_vmulps(hG1, hG1, h);

            // compute diff_states_t_l = diff_states_t_l + dhG1 * G1
            uni_vmovups(dH, ptr[addr_diff_states_t_l_reg]);
            uni_vfmadd231ps(dH, dhG1, G1);

            // downconvert and write data
            to_src<scratch_data_t>(sg_addr(1), dG1, vlen);
            to_src<scratch_data_t>(ptr[addr_scratch_cell_reg], hG1, vlen);
            uni_vmovups(ptr[addr_diff_states_t_l_reg], dH);

            // increment address pointers (f32 buffers advance by vlen,
            // scratch/src-typed buffers by vlen_scratch)
            add(addr_ws_gates_reg, vlen_scratch);
            add(addr_scratch_gates_reg, vlen_scratch);
            add(addr_dhG1_reg, vlen);
            add(addr_diff_states_t_l_reg, vlen);
            add(addr_states_tm1_l_reg, vlen_scratch);
            add(addr_scratch_cell_reg, vlen_scratch);
            inc_regs(vlen);

            // increment loop counter
            sub(loop_cnt, vlen_scratch);
            cmp(loop_cnt, vlen_scratch);
            jge(vector_loop_start_label);
        }
        L(vector_loop_end_label);

        cmp(loop_cnt, 0);
        je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
        // Same computation as above, but scalar (movss) one element at a time
        // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
        L(rem_loop_start_label);
        {
            Xmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx),
                    dH(dH_idx), tmp1(tmp1_idx), h(h_idx);

            to_float<src_data_t>(G1, wg_addr(1), hstate_dt_size);
            to_float<src_data_t>(h, ptr[addr_states_tm1_l_reg], hstate_dt_size);

            // compute dG1
            uni_vmovss(dG1, G1);
            uni_vmovss(tmp1, G1);
            uni_vfnmadd231ps(dG1, tmp1, tmp1); // dG1 = G1 - G1^2
            uni_vmulss(dG1, dG1, h);
            uni_vmovss(dhG1, ptr[addr_dhG1_reg]);
            uni_vmulss(dG1, dG1, dhG1); // dG1 = dhG1 * h * (G1 - G1^2)

            // compute hG1 = h * G1
            uni_vmovss(hG1, G1);
            uni_vmulss(hG1, hG1, h);

            // compute diff_states_t_l = diff_states_t_l + dhG1 * G1
            uni_vmovss(dH, ptr[addr_diff_states_t_l_reg]);
            uni_vfmadd231ps(dH, dhG1, G1);

            // downconvert and write data
            to_src<scratch_data_t>(sg_addr(1), dG1, hstate_dt_size);
            to_src<scratch_data_t>(
                    ptr[addr_scratch_cell_reg], hG1, hstate_dt_size);
            uni_vmovss(ptr[addr_diff_states_t_l_reg], dH);

            // increment address pointers by one element
            add(addr_ws_gates_reg, scratch_dt_size);
            add(addr_scratch_gates_reg, scratch_dt_size);
            add(addr_dhG1_reg, hstate_dt_size);
            add(addr_diff_states_t_l_reg, hstate_dt_size);
            add(addr_states_tm1_l_reg, scratch_dt_size);
            add(addr_scratch_cell_reg, scratch_dt_size);
            inc_regs(hstate_dt_size);

            // increment loop counter
            sub(loop_cnt, scratch_dt_size);
            cmp(loop_cnt, 0);
            jg(rem_loop_start_label);
        }
        L(rem_loop_end_label);

        postamble();

        // emit constant tables used by the conversion/injector helpers
        init_table(vlen);
    }
};
222 
223 } // namespace x64
224 } // namespace cpu
225 } // namespace impl
226 } // namespace dnnl
227 
228 #endif
229