1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "cpu/x64/rnn/jit_diff_weights_peephole.hpp"
18 #include "common/c_types_map.hpp"
19 #include "cpu/rnn/rnn_utils.hpp"
20
21 namespace dnnl {
22 namespace impl {
23 namespace cpu {
24 namespace x64 {
25
// Configures the kernel: data types of the three buffers, the block/tail
// sizes, and the io helper used to emit loads/stores.
//
// @param rnn            RNN configuration; only consulted to choose bf16 vs
//                       f32 for the scratch-gates input (is_bf16()).
// @param dhc_block_size number of dhc elements this kernel processes; its
//                       remainder w.r.t. the SIMD width becomes the
//                       opmask-predicated tail.
jit_diff_weights_peephole_t::jit_diff_weights_peephole_t(
        const rnn_utils::rnn_conf_t &rnn, const dim_t dhc_block_size)
    : c_states_dt_(data_type::f32)
    , scratch_dt_(rnn.is_bf16() ? data_type::bf16 : data_type::f32)
    , dst_dt_(data_type::f32)
    , compute_block_size_(dhc_block_size)
    , tail_size_(dhc_block_size % simd_w_)
    // Target avx512_core_bf16 when available so bf16 scratch values can be
    // loaded/converted natively; otherwise fall back to avx512_core.
    , io_(this, mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core,
              {c_states_dt_, scratch_dt_, dst_dt_}, {},
              io::io_tail_conf_t {static_cast<std::size_t>(simd_w_),
                      static_cast<std::size_t>(tail_size_), tail_opmask_, 0,
                      reg_tmp_}) {}
38
// Top-level code emitter: prologue, load runtime buffer pointers from the
// call params, one-time setup (tail mask), the main compute loop, epilogue.
// The call order here is the order of the emitted machine code.
void jit_diff_weights_peephole_t::generate() {
    preamble();
    load_addresses();
    init();
    compute_loop();
    postamble();
}
46
47 #define PARAM_OFF(x) offsetof(jit_diff_weights_peephole_t::call_params_t, x)
48
// Loads the c_states / scratch_gates / dst pointers from the call_params_t
// structure (passed in abi_param1) into their dedicated registers.
void jit_diff_weights_peephole_t::load_addresses() {
    mov(reg_c_states_, ptr[abi_param1 + PARAM_OFF(c_states)]);
    mov(reg_scratch_gates_, ptr[abi_param1 + PARAM_OFF(scratch_gates)]);
    mov(reg_dst_, ptr[abi_param1 + PARAM_OFF(dst)]);
}
54
55 #undef PARAM_OFF
56
// One-time setup: initializes the opmask that predicates the partial (tail)
// vector iteration, needed only when compute_block_size_ is not a multiple
// of the SIMD width.
void jit_diff_weights_peephole_t::init() {
    if (tail_size_) { io_.prepare_tail_mask(); }
}
60
compute_loop()61 void jit_diff_weights_peephole_t::compute_loop() {
62
63 Xbyak::Label unroll_loop, unroll_loop_tail;
64
65 mov(loop_cnt_, compute_block_size_);
66 xor_(reg_offset_, reg_offset_);
67
68 const size_t offt_max = max_unrolling * simd_w_;
69 const size_t full_unroling_steps = compute_block_size_ / offt_max;
70
71 if (full_unroling_steps) {
72 L(unroll_loop);
73 {
74 cmp(loop_cnt_, offt_max);
75 jl(unroll_loop_tail, T_NEAR);
76
77 compute_dst(max_unrolling, false /*tail*/);
78 sub(loop_cnt_, offt_max);
79 add(reg_offset_, offt_max);
80 jmp(unroll_loop);
81 }
82 }
83
84 const size_t full_blocks_left = (compute_block_size_ - tail_size_
85 - (full_unroling_steps * offt_max))
86 / simd_w_;
87
88 L(unroll_loop_tail);
89 {
90 if (full_blocks_left) {
91 compute_dst(full_blocks_left, false /*tail*/);
92 if (tail_size_) {
93 const size_t offt = full_blocks_left * simd_w_;
94 add(reg_offset_, offt);
95 }
96 }
97 if (tail_size_) { compute_dst(1u /*unrolling factor*/, true /*tail*/); }
98 }
99 }
100
compute_dst(size_t unrolling_factor,bool tail)101 void jit_diff_weights_peephole_t::compute_dst(
102 size_t unrolling_factor, bool tail) {
103
104 static constexpr dim_t number_vmm_single_compute = 2;
105
106 const auto get_compute_zmm = [=](size_t base_idx, size_t unroll_group) {
107 return Xbyak::Zmm(base_idx + unroll_group * number_vmm_single_compute);
108 };
109
110 const auto get_addr = [&](const Xbyak::Reg64 ®_base, const dim_t offt,
111 const data_type_t dt) {
112 const auto dt_size = types::data_type_size(dt);
113 return ptr[reg_base + reg_offset_ * dt_size + offt * dt_size];
114 };
115
116 static constexpr size_t dst_idx = 0;
117 static constexpr size_t scratch_idx = 1;
118 const auto io_dst = io_.at(dst_dt_);
119 const auto io_scratch = io_.at(scratch_dt_);
120
121 for (size_t unroll_group = 0; unroll_group < unrolling_factor;
122 ++unroll_group) {
123
124 const auto dst_zmm = get_compute_zmm(dst_idx, unroll_group);
125 const auto scratch_zmm = get_compute_zmm(scratch_idx, unroll_group);
126 const auto unroll_offset = unroll_group * simd_w_;
127 const auto dst_addr = get_addr(reg_dst_, unroll_offset, dst_dt_);
128 io_dst->load(dst_addr, dst_zmm, tail);
129 io_scratch->load(
130 get_addr(reg_scratch_gates_, unroll_offset, scratch_dt_),
131 scratch_zmm, tail);
132 const auto dst_zmm_masked = tail ? dst_zmm | tail_opmask_ : dst_zmm;
133 uni_vfmadd231ps(dst_zmm_masked, scratch_zmm,
134 get_addr(reg_c_states_, unroll_offset, c_states_dt_));
135 io_dst->store(dst_zmm, dst_addr, tail);
136 }
137 }
138
139 } // namespace x64
140 } // namespace cpu
141 } // namespace impl
142 } // namespace dnnl
143