1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP
18 #define CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP
19 
20 #include <map>
21 
22 #include "cpu/cpu_prelu_pd.hpp"
23 #include "cpu/x64/cpu_isa_traits.hpp"
24 #include "cpu/x64/prelu/jit_prelu_base_kernel.hpp"
25 #include "cpu/x64/utils/jit_io_helper.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 namespace x64 {
31 
32 class jit_prelu_forward_kernel_t : public jit_prelu_base_kernel_t {
33 public:
34     static jit_prelu_forward_kernel_t *create(const cpu_prelu_fwd_pd_t *pd);
35 
36     struct call_params_t {
37         const void *src = nullptr, *weights = nullptr, *dst = nullptr;
38         size_t compute_data_size = 0u;
39     };
40 
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_forward_kernel_t)41     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_forward_kernel_t)
42 
43     void operator()(jit_prelu_forward_kernel_t::call_params_t *params) {
44         jit_generator::operator()(params);
45     }
46 
47 protected:
48     const data_type_t src_dt_;
49     const data_type_t wei_dt_;
50     const data_type_t dst_dt_;
51     const size_t dst_tail_block_;
52 
53     jit_prelu_forward_kernel_t(const cpu_prelu_fwd_pd_t *pd,
54             const cpu_isa_t &isa, const int vlen,
55             const size_t number_vmm_single_compute);
56     Xbyak::Address data_ptr(int arg_num, size_t offt = 0);
57 
58 private:
59     bool any_tensor_bf16() const override;
60     void load_kernel_call_params() override;
finalize()61     void finalize() override {}
62 
63 protected:
64     const Xbyak::Reg64 &reg_src_ = r10;
65     const Xbyak::Reg64 &reg_dst_ = r11;
66     const Xbyak::Reg64 &reg_weights_ = r12;
67     const cpu_prelu_fwd_pd_t *pd_;
68 };
69 
70 template <typename Vmm>
71 class jit_uni_prelu_forward_kernel_t : public jit_prelu_forward_kernel_t {
72 public:
73     jit_uni_prelu_forward_kernel_t(
74             const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa);
75     ~jit_uni_prelu_forward_kernel_t() override;
76 
77 private:
78     using jit_generator::uni_vfmadd132ps;
79 
80     void prepare_kernel_const_vars() override;
81     void compute_dst(size_t unrolling_factor, bool tail) override;
82     bool can_load_wei_from_addr_directly(bool tail) const noexcept;
83 
84     Vmm get_or_load_weights(
85             const Xbyak::Address &src_addr, const Vmm &dst_vmm, bool tail);
86     void uni_vfmadd132ps(
87             const Vmm &x1, const Vmm &x2, const Xbyak::Operand &op, bool tail);
88     std::map<data_type_t, io::io_saturation_conf_t>
89     create_saturation_vmm_map() const;
90 
91     const bool saturation_needed_ = false;
92     const Vmm vmm_zeros_;
93     const Vmm dst_saturate_ubound_;
94     const Vmm tail_vmm_mask_;
95     const Vmm weights_const_vmm_;
96     const size_t number_vmm_single_compute_ = 0;
97     const Xbyak::Opmask &tail_opmask_ = k1;
98     const Xbyak::Reg64 &reg_tmp_ = r15;
99 
100     io::jit_io_multi_dt_helper_t<Vmm> io_;
101 };
102 
103 } // namespace x64
104 } // namespace cpu
105 } // namespace impl
106 } // namespace dnnl
107 
108 #endif
109