1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "cpu/x64/rnn/jit_diff_weights_peephole.hpp"
18 #include "common/c_types_map.hpp"
19 #include "cpu/rnn/rnn_utils.hpp"
20 
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25 
// Ctor: records the data types involved in the diff-weights-peephole
// computation and configures the IO injector, including the opmask used
// for the tail when dhc_block_size is not a multiple of the SIMD width.
jit_diff_weights_peephole_t::jit_diff_weights_peephole_t(
        const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block_size)
    : c_states_dt_(data_type::f32) // cell states are kept in f32
    , scratch_dt_(rnn.is_bf16() ? data_type::bf16 : data_type::f32)
    , dst_dt_(data_type::f32) // diff weights are accumulated in f32
    , compute_block_size_(dhc_block_size)
    // Elements left over after processing full SIMD vectors.
    , tail_size_(dhc_block_size % simd_w_)
    , io_(this, mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core,
              {c_states_dt_, scratch_dt_, dst_dt_}, {},
              io::io_tail_conf_t {static_cast<std::size_t>(simd_w_),
                      static_cast<std::size_t>(tail_size_), tail_opmask_, 0,
                      reg_tmp_}) {}
38 
// Top-level kernel emission: prologue, load of runtime arguments,
// one-time setup (tail mask), the main compute loop, then the epilogue.
void jit_diff_weights_peephole_t::generate() {
    preamble();
    load_addresses(); // fetch buffer pointers from call_params_t
    init(); // prepare the tail opmask when tail_size_ > 0
    compute_loop();
    postamble();
}
46 
47 #define PARAM_OFF(x) offsetof(jit_diff_weights_peephole_t::call_params_t, x)
48 
load_addresses()49 void jit_diff_weights_peephole_t::load_addresses() {
50     mov(reg_c_states_, ptr[abi_param1 + PARAM_OFF(c_states)]);
51     mov(reg_scratch_gates_, ptr[abi_param1 + PARAM_OFF(scratch_gates)]);
52     mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
53 }
54 
55 #undef PARAM_OFF
56 
init()57 void jit_diff_weights_peephole_t::init() {
58     if (tail_size_) { io_.prepare_tail_mask(); }
59 }
60 
// Emits the main loop over the dhc block: a runtime loop unrolled by
// max_unrolling vectors per iteration, then straight-line code for the
// remaining full vectors, then a single masked step for the tail.
void jit_diff_weights_peephole_t::compute_loop() {

    Xbyak::Label unroll_loop, unroll_loop_tail;

    // loop_cnt_  = elements still to process (runtime counter);
    // reg_offset_ = current element offset into all three buffers.
    mov(loop_cnt_, compute_block_size_);
    xor_(reg_offset_, reg_offset_);

    // Elements consumed by one fully unrolled loop iteration.
    const size_t offt_max = max_unrolling * simd_w_;
    const size_t full_unroling_steps = compute_block_size_ / offt_max;

    // Emit the unrolled runtime loop only if it executes at least once.
    if (full_unroling_steps) {
        L(unroll_loop);
        {
            cmp(loop_cnt_, offt_max);
            jl(unroll_loop_tail, T_NEAR); // fewer than offt_max left -> done

            compute_dst(max_unrolling, false /*tail*/);
            sub(loop_cnt_, offt_max);
            add(reg_offset_, offt_max);
            jmp(unroll_loop);
        }
    }

    // Full vectors left after the unrolled iterations. Known at
    // code-generation time, so they are emitted without a runtime loop.
    const size_t full_blocks_left = (compute_block_size_ - tail_size_
                                            - (full_unroling_steps * offt_max))
            / simd_w_;

    L(unroll_loop_tail);
    {
        if (full_blocks_left) {
            compute_dst(full_blocks_left, false /*tail*/);
            if (tail_size_) {
                // Advance the offset only when a masked tail step follows.
                const size_t offt = full_blocks_left * simd_w_;
                add(reg_offset_, offt);
            }
        }
        if (tail_size_) { compute_dst(1u /*unrolling factor*/, true /*tail*/); }
    }
}
100 
// Emits `unrolling_factor` vectorized steps of
//     dst += scratch_gates * c_states
// with c_states read directly from memory by the fma. When `tail` is
// set, loads/stores and the fma are merge-masked by tail_opmask_ so
// only tail_size_ lanes are touched.
void jit_diff_weights_peephole_t::compute_dst(
        size_t unrolling_factor, bool tail) {

    // Each unroll group needs two vector registers: dst and scratch.
    static constexpr dim_t number_vmm_single_compute = 2;

    // Register layout: group g uses zmm(2*g) for dst, zmm(2*g + 1) for
    // scratch (base_idx selects which of the pair).
    const auto get_compute_zmm = [=](size_t base_idx, size_t unroll_group) {
        return Xbyak::Zmm(base_idx + unroll_group * number_vmm_single_compute);
    };

    // Effective address = base + (runtime element offset + compile-time
    // element offset) scaled by the data type size.
    const auto get_addr = [&](const Xbyak::Reg64 &reg_base, const dim_t offt,
                                  const data_type_t dt) {
        const auto dt_size = types::data_type_size(dt);
        return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
    };

    static constexpr size_t dst_idx = 0;
    static constexpr size_t scratch_idx = 1;
    const auto io_dst = io_.at(dst_dt_);
    const auto io_scratch = io_.at(scratch_dt_);

    for (size_t unroll_group = 0; unroll_group < unrolling_factor;
            ++unroll_group) {

        const auto dst_zmm = get_compute_zmm(dst_idx, unroll_group);
        const auto scratch_zmm = get_compute_zmm(scratch_idx, unroll_group);
        const auto unroll_offset = unroll_group * simd_w_;
        const auto dst_addr = get_addr(reg_dst_, unroll_offset, dst_dt_);
        // Loads go through the IO injector; for bf16 scratch data it is
        // expected to up-convert lanes to f32 for the fma — NOTE(review):
        // confirmed by scratch_dt_ selection, conversion details live in
        // the injector.
        io_dst->load(dst_addr, dst_zmm, tail);
        io_scratch->load(
                get_addr(reg_scratch_gates_, unroll_offset, scratch_dt_),
                scratch_zmm, tail);
        // Merge-mask the fma itself so out-of-range dst lanes are kept.
        const auto dst_zmm_masked = tail ? dst_zmm | tail_opmask_ : dst_zmm;
        uni_vfmadd231ps(dst_zmm_masked, scratch_zmm,
                get_addr(reg_c_states_, unroll_offset, c_states_dt_));
        io_dst->store(dst_zmm, dst_addr, tail);
    }
}
138 
139 } // namespace x64
140 } // namespace cpu
141 } // namespace impl
142 } // namespace dnnl
143