1 /******************************************************************************* 2 * Copyright 2019-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP 18 #define CPU_X64_RNN_JIT_UNI_GRU_CELL_POSTGEMM_2_BWD_HPP 19 20 #include "cpu/x64/rnn/jit_uni_rnn_common_postgemm.hpp" 21 22 namespace dnnl { 23 namespace impl { 24 namespace cpu { 25 namespace x64 { 26 27 template <cpu_isa_t isa, impl::data_type_t src_data_t, 28 impl::data_type_t scratch_data_t> 29 struct jit_uni_gru_cell_postgemm_part2_bwd : public jit_uni_rnn_postgemm { 30 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_gru_cell_postgemm_part2_bwd) 31 jit_uni_gru_cell_postgemm_part2_bwddnnl::impl::cpu::x64::jit_uni_gru_cell_postgemm_part2_bwd32 jit_uni_gru_cell_postgemm_part2_bwd( 33 const rnn_utils::rnn_conf_t &rnn, const rnn_pd_t *pd) 34 : jit_uni_rnn_postgemm(rnn, pd) {} 35 ~jit_uni_gru_cell_postgemm_part2_bwddnnl::impl::cpu::x64::jit_uni_gru_cell_postgemm_part2_bwd36 ~jit_uni_gru_cell_postgemm_part2_bwd() {} 37 initdnnl::impl::cpu::x64::jit_uni_gru_cell_postgemm_part2_bwd38 status_t init(data_type_t sdt) override { 39 jit_uni_rnn_postgemm::init(src_data_t); 40 return create_kernel(); 41 } 42 43 protected: 44 // register size in bytes 45 using Vmm = typename cpu_isa_traits<isa>::Vmm; 46 size_t vlen = cpu_isa_traits<isa>::vlen; 47 size_t vlen_scratch 48 = vlen / (sizeof(float) / types::data_type_size(scratch_data_t)); 49 size_t hstate_dt_size = sizeof(float); 50 size_t gate_dt_size = types::data_type_size(scratch_data_t); 51 size_t scratch_dt_size = types::data_type_size(scratch_data_t); 52 generatednnl::impl::cpu::x64::jit_uni_gru_cell_postgemm_part2_bwd53 void generate() override { 54 using namespace Xbyak; 55 56 // Labels declaration 57 Label vector_loop_start_label, vector_loop_inc_regs, 58 vector_loop_end_label; 59 Label rem_loop_start_label, rem_loop_inc_regs, rem_loop_end_label; 60 61 // Register map 62 Reg64 loop_cnt(rbx); // loop counter 63 64 // We skip vmm0 as it can be used by the injector for masks on sse4.1 65 enum { 66 dG1_idx = 1, 67 dhG1_idx = 2, 68 hG1_idx = 3, 69 G1_idx = 4, 70 dH_idx = 5, 71 tmp1_idx = 6, 72 h_idx = 7 73 }; 74 75 // We start code generations here 76 preamble(); 77 78 // extract addresses passed as parameter 79 auto addr_ws_gates_reg = abi_param1; 80 auto addr_scratch_gates_reg = abi_param2; 81 // auto addr_diff_states_t_lp1_reg = abi_param3; // not needed 82 // auto addr_diff_states_tp1_l_reg = abi_param4; // not needed 83 #ifdef _WIN32 84 auto addr_diff_states_t_l_reg = r10; 85 auto addr_states_tm1_l_reg = r11; 86 auto addr_scratch_cell_reg = r12; 87 // auto addr_ws_grid_reg = rsi; // not needed 88 auto addr_dhG1_reg = rsi; 89 auto base_args = get_stack_params_address(); 90 mov(addr_diff_states_t_l_reg, ptr[base_args]); 91 mov(addr_states_tm1_l_reg, ptr[base_args + 8]); 92 mov(addr_scratch_cell_reg, ptr[base_args + 16]); 93 // mov(addr_ws_grid_reg, ptr[base_args + 24]); 94 mov(addr_dhG1_reg, ptr[base_args + 32]); 95 #else 96 auto addr_diff_states_t_l_reg = abi_param5; 97 auto addr_states_tm1_l_reg = abi_param6; 98 auto addr_scratch_cell_reg = r10; 99 // auto addr_ws_grid_reg = r11; // not needed 100 auto addr_dhG1_reg = r11; 101 auto base_args = get_stack_params_address(); 102 mov(addr_scratch_cell_reg, ptr[base_args]); 103 // mov(addr_ws_grid_reg, ptr[base_args + 8]); 104 mov(addr_dhG1_reg, ptr[base_args + 16]); 105 #endif 106 107 // helper lambda to address the gates and biases 108 auto sg_addr = [&](int i) { 109 return ptr[addr_scratch_gates_reg + i * rnn_.dhc * scratch_dt_size]; 110 }; 111 auto wg_addr = [&](int i) { 112 return ptr[addr_ws_gates_reg + i * rnn_.dhc * gate_dt_size]; 113 }; 114 115 // initialize registers with addresses and constants 116 init_regs(vlen); 117 118 mov(loop_cnt, rnn_.dhc * scratch_dt_size); 119 cmp(loop_cnt, vlen_scratch); 120 jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR); 121 122 L(vector_loop_start_label); 123 { 124 Vmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx), 125 dH(dH_idx), tmp1(tmp1_idx), h(h_idx); 126 127 to_float<src_data_t>(G1, wg_addr(1), vlen); 128 to_float<src_data_t>(h, ptr[addr_states_tm1_l_reg], vlen); 129 130 // compute dG1 131 uni_vmovups(dG1, G1); 132 uni_vmovups(tmp1, G1); 133 uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2) 134 uni_vmulps(dG1, dG1, h); 135 uni_vmovups(dhG1, ptr[addr_dhG1_reg]); 136 uni_vmulps(dG1, dG1, dhG1); // dhG1 * h * (G0 - G0^2) * dHt 137 138 // compute hG1 139 uni_vmovups(hG1, G1); 140 uni_vmulps(hG1, hG1, h); 141 142 // compute diff_states_t_l = diff_states_t_l + dhG1 * G1 143 uni_vmovups(dH, ptr[addr_diff_states_t_l_reg]); 144 uni_vfmadd231ps(dH, dhG1, G1); 145 146 // downconvert and write data 147 to_src<scratch_data_t>(sg_addr(1), dG1, vlen); 148 to_src<scratch_data_t>(ptr[addr_scratch_cell_reg], hG1, vlen); 149 uni_vmovups(ptr[addr_diff_states_t_l_reg], dH); 150 151 // increment address pointers 152 add(addr_ws_gates_reg, vlen_scratch); 153 add(addr_scratch_gates_reg, vlen_scratch); 154 add(addr_dhG1_reg, vlen); 155 add(addr_diff_states_t_l_reg, vlen); 156 add(addr_states_tm1_l_reg, vlen_scratch); 157 add(addr_scratch_cell_reg, vlen_scratch); 158 inc_regs(vlen); 159 160 // increment loop counter 161 sub(loop_cnt, vlen_scratch); 162 cmp(loop_cnt, vlen_scratch); 163 jge(vector_loop_start_label); 164 } 165 L(vector_loop_end_label); 166 167 cmp(loop_cnt, 0); 168 je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR); 169 // Same code as above, we just use movuss for accessing inputs 170 // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar 171 L(rem_loop_start_label); 172 { 173 Xmm dG1(dG1_idx), dhG1(dhG1_idx), hG1(hG1_idx), G1(G1_idx), 174 dH(dH_idx), tmp1(tmp1_idx), h(h_idx); 175 176 to_float<src_data_t>(G1, wg_addr(1), hstate_dt_size); 177 to_float<src_data_t>(h, ptr[addr_states_tm1_l_reg], hstate_dt_size); 178 179 // compute dG1 180 uni_vmovss(dG1, G1); 181 uni_vmovss(tmp1, G1); 182 uni_vfnmadd231ps(dG1, tmp1, tmp1); // (G1 - G1^2) 183 uni_vmulss(dG1, dG1, h); 184 uni_vmovss(dhG1, ptr[addr_dhG1_reg]); 185 uni_vmulss(dG1, dG1, dhG1); // dhG1 * h * (G0 - G0^2) * dHt 186 187 // compute hG1 188 uni_vmovss(hG1, G1); 189 uni_vmulss(hG1, hG1, h); 190 191 // compute diff_states_t_l = diff_states_t_l + dhG1 * G1 192 uni_vmovss(dH, ptr[addr_diff_states_t_l_reg]); 193 uni_vfmadd231ps(dH, dhG1, G1); 194 195 // downconvert and write data 196 to_src<scratch_data_t>(sg_addr(1), dG1, hstate_dt_size); 197 to_src<scratch_data_t>( 198 ptr[addr_scratch_cell_reg], hG1, hstate_dt_size); 199 uni_vmovss(ptr[addr_diff_states_t_l_reg], dH); 200 201 // increment address pointers 202 add(addr_ws_gates_reg, scratch_dt_size); 203 add(addr_scratch_gates_reg, scratch_dt_size); 204 add(addr_dhG1_reg, hstate_dt_size); 205 add(addr_diff_states_t_l_reg, hstate_dt_size); 206 add(addr_states_tm1_l_reg, scratch_dt_size); 207 add(addr_scratch_cell_reg, scratch_dt_size); 208 inc_regs(hstate_dt_size); 209 210 // increment loop counter 211 sub(loop_cnt, scratch_dt_size); 212 cmp(loop_cnt, 0); 213 jg(rem_loop_start_label); 214 } 215 L(rem_loop_end_label); 216 217 postamble(); 218 219 init_table(vlen); 220 } 221 }; 222 223 } // namespace x64 224 } // namespace cpu 225 } // namespace impl 226 } // namespace dnnl 227 228 #endif 229