/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16 
17 #ifndef CPU_X64_PRELU_JIT_PRELU_BACKWARD_KERNEL_HPP
18 #define CPU_X64_PRELU_JIT_PRELU_BACKWARD_KERNEL_HPP
19 
20 #include <map>
21 #include <utility>
22 
23 #include "cpu/cpu_prelu_pd.hpp"
24 #include "cpu/x64/cpu_isa_traits.hpp"
25 #include "cpu/x64/prelu/jit_prelu_base_kernel.hpp"
26 #include "cpu/x64/utils/jit_io_helper.hpp"
27 
28 namespace dnnl {
29 namespace impl {
30 namespace cpu {
31 namespace x64 {
32 
33 class jit_prelu_backward_kernel_t : public jit_prelu_base_kernel_t {
34 public:
35     static jit_prelu_backward_kernel_t *create(const cpu_prelu_bwd_pd_t *pd);
36 
37     struct call_params_t {
38         const void *src = nullptr, *weights = nullptr, *dst_diff = nullptr;
39         void *src_diff = nullptr, *weights_diff = nullptr;
40         size_t compute_data_size = 0u;
41     };
42 
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_backward_kernel_t)43     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_backward_kernel_t)
44 
45     void operator()(jit_prelu_backward_kernel_t::call_params_t *params) {
46         jit_generator::operator()(params);
47     }
48 
49 protected:
50     jit_prelu_backward_kernel_t(const cpu_prelu_bwd_pd_t *pd,
51             const cpu_isa_t &isa, const int vlen,
52             size_t number_vmm_single_compute);
53     Xbyak::Address data_ptr(int arg_num, size_t offt = 0);
54 
55     const cpu_prelu_bwd_pd_t *pd_;
56     const Xbyak::Reg64 &reg_weights_ = r10;
57     const Xbyak::Reg64 &reg_weights_diff_ = r11;
58 
59     const data_type_t src_dt_;
60     const data_type_t wei_dt_;
61     const data_type_t diff_src_dt_;
62     const data_type_t diff_dst_dt_;
63     const data_type_t diff_wei_dt_;
64     const size_t diff_src_block_tail_;
65     const size_t diff_wei_block_tail_;
66 
67     const Xbyak::Reg64 &reg_src_ = r12;
68     const Xbyak::Reg64 &reg_src_diff_ = r13;
69     const Xbyak::Reg64 &reg_dst_diff_ = r14;
70 
71 private:
72     bool any_tensor_bf16() const override;
73     void load_kernel_call_params() override;
74 };
75 
76 template <typename Vmm>
77 class jit_uni_prelu_backward_kernel_t : public jit_prelu_backward_kernel_t {
78 public:
79     jit_uni_prelu_backward_kernel_t(
80             const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa);
81     ~jit_uni_prelu_backward_kernel_t() override;
82 
83 private:
84     void prepare_kernel_const_vars() override;
85     void compute_dst(size_t unrolling_factor, bool tail) override;
86     const Xbyak::Operand &get_or_load_weights(
87             const Xbyak::Address &src_addr, const Vmm &dst_vmm, bool tail);
88     void accumulate_weights_diff(const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
89             const Xbyak::Address &dst_addr, bool tail);
90     void finalize() override;
91     std::map<data_type_t, io::io_saturation_conf_t>
92     create_saturation_vmm_map() const;
93 
94     const bool saturation_needed_diff_src_;
95     const bool saturation_needed_diff_weights_;
96 
97     const Vmm vmm_zeros_;
98     const Vmm saturation_ubound_diff_src_;
99     const Vmm saturation_ubound_diff_weights_;
100 
101     const Vmm tail_vmm_mask_;
102     const Vmm vmm_ones_;
103     const Vmm weights_const_vmm_;
104     const Vmm weights_diff_acc_vmm_;
105 
106     const Xbyak::Opmask &tail_opmask_ = k1;
107     const Xbyak::Reg64 &reg_tmp_ = r15;
108 
109     io::jit_io_multi_dt_helper_t<Vmm> io_;
110 };
111 
112 } // namespace x64
113 } // namespace cpu
114 } // namespace impl
115 } // namespace dnnl
116 
117 #endif
118