1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_PRELU_JIT_PRELU_REDUCTION_HPP
18 #define CPU_X64_PRELU_JIT_PRELU_REDUCTION_HPP
19 
20 #include <memory>
21 
22 #include "cpu/cpu_prelu_pd.hpp"
23 #include "cpu/x64/cpu_isa_traits.hpp"
24 #include "cpu/x64/jit_generator.hpp"
25 #include "cpu/x64/utils/jit_io_helper.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31 
32 class jit_prelu_reduction_kernel_t : public jit_generator {
33 public:
34     static jit_prelu_reduction_kernel_t *create(const cpu_prelu_bwd_pd_t *pd);
35 
36     struct call_params_t {
37         size_t reduction_blocks = 0;
38         const void *weights_diff_scratch = nullptr;
39         void *weights_diff = nullptr;
40         bool tail = false;
41         bool is_last_c_blk = false;
42     };
43 
44     void generate() override;
45     size_t simd_w() const;
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_reduction_kernel_t)46     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_reduction_kernel_t)
47 
48     void operator()(jit_prelu_reduction_kernel_t::call_params_t *params) {
49         jit_generator::operator()(params);
50     }
51 
52 private:
53     void load_kernel_call_params();
54     virtual size_t get_unrolling_factor(bool tail) const = 0;
55     virtual void compute_dst(int unrolling_factor, bool tail) = 0;
56     virtual void prepare_kernel_const_vars(bool tail) = 0;
57     virtual void finalize(bool tail) = 0;
58     void generate(bool tail);
59 
60     const Xbyak::Reg64 &reg_reduction_blocks_ = r8;
61     const Xbyak::Reg64 &reg_weights_diff_scratch_ = r10;
62     const Xbyak::Reg8 &reg_tail_ = r12b;
63 
64     const size_t scratchpad_c_block_offset_ = 0;
65 
66 protected:
67     jit_prelu_reduction_kernel_t(const cpu_prelu_bwd_pd_t *pd, int simd_w);
68     Xbyak::Address diff_scratch_ptr(int unrolling_group) const;
69     int reserve_vmm();
70 
71     const size_t simd_w_ = 0;
72     const data_type_t data_type_;
73     const size_t tail_size_ = 0;
74     const Xbyak::Reg64 &reg_offset_ = r9;
75     const Xbyak::Reg64 &reg_weights_diff_ = r11;
76     const Xbyak::Reg8 &reg_last_c_blk_byte_ = r13b;
77     size_t number_reserved_vmms_ = 0;
78     size_t tail_block_size_ = 0;
79     size_t c_blk_nelems_ = 0;
80 };
81 
82 template <typename Vmm>
83 class jit_uni_prelu_reduction_kernel_t : public jit_prelu_reduction_kernel_t {
84 public:
85     jit_uni_prelu_reduction_kernel_t(
86             const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa);
87 
88 private:
89     size_t get_unrolling_factor(bool tail) const override;
90     void prepare_kernel_const_vars(bool tail) override;
91     void finalize(bool tail) override;
92     void compute_dst(int unrolling_factor, bool tail) override;
93 
94     const cpu_isa_t isa_;
95     const bool saturation_needed_;
96     const Vmm accumulator_;
97     const Vmm tail_vmm_mask_;
98     const Vmm saturation_lower_bound_;
99     const Vmm saturation_upper_bound_;
100 
101     const Xbyak::Opmask &tail_opmask_ = k1;
102     const Xbyak::Reg64 &reg_tmp_ = r15;
103 
104     io::jit_io_helper_t<Vmm> io_;
105 };
106 
107 } // namespace x64
108 } // namespace cpu
109 } // namespace impl
110 } // namespace dnnl
111 
112 #endif
113