1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP 18 #define CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP 19 20 #include <map> 21 22 #include "cpu/cpu_prelu_pd.hpp" 23 #include "cpu/x64/cpu_isa_traits.hpp" 24 #include "cpu/x64/prelu/jit_prelu_base_kernel.hpp" 25 #include "cpu/x64/utils/jit_io_helper.hpp" 26 27 namespace dnnl { 28 namespace impl { 29 namespace cpu { 30 namespace x64 { 31 32 class jit_prelu_forward_kernel_t : public jit_prelu_base_kernel_t { 33 public: 34 static jit_prelu_forward_kernel_t *create(const cpu_prelu_fwd_pd_t *pd); 35 36 struct call_params_t { 37 const void *src = nullptr, *weights = nullptr, *dst = nullptr; 38 size_t compute_data_size = 0u; 39 }; 40 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_forward_kernel_t)41 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_forward_kernel_t) 42 43 void operator()(jit_prelu_forward_kernel_t::call_params_t *params) { 44 jit_generator::operator()(params); 45 } 46 47 protected: 48 const data_type_t src_dt_; 49 const data_type_t wei_dt_; 50 const data_type_t dst_dt_; 51 const size_t dst_tail_block_; 52 53 jit_prelu_forward_kernel_t(const cpu_prelu_fwd_pd_t *pd, 54 const cpu_isa_t &isa, const int vlen, 55 const size_t number_vmm_single_compute); 56 Xbyak::Address data_ptr(int arg_num, size_t offt = 0); 57 58 private: 59 bool any_tensor_bf16() const override; 60 void load_kernel_call_params() override; finalize()61 void finalize() override {} 62 63 protected: 64 const Xbyak::Reg64 ®_src_ = r10; 65 const Xbyak::Reg64 ®_dst_ = r11; 66 const Xbyak::Reg64 ®_weights_ = r12; 67 const cpu_prelu_fwd_pd_t *pd_; 68 }; 69 70 template <typename Vmm> 71 class jit_uni_prelu_forward_kernel_t : public jit_prelu_forward_kernel_t { 72 public: 73 jit_uni_prelu_forward_kernel_t( 74 const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa); 75 ~jit_uni_prelu_forward_kernel_t() override; 76 77 private: 78 using jit_generator::uni_vfmadd132ps; 79 80 void prepare_kernel_const_vars() override; 81 void compute_dst(size_t unrolling_factor, bool tail) override; 82 bool can_load_wei_from_addr_directly(bool tail) const noexcept; 83 84 Vmm get_or_load_weights( 85 const Xbyak::Address &src_addr, const Vmm &dst_vmm, bool tail); 86 void uni_vfmadd132ps( 87 const Vmm &x1, const Vmm &x2, const Xbyak::Operand &op, bool tail); 88 std::map<data_type_t, io::io_saturation_conf_t> 89 create_saturation_vmm_map() const; 90 91 const bool saturation_needed_ = false; 92 const Vmm vmm_zeros_; 93 const Vmm dst_saturate_ubound_; 94 const Vmm tail_vmm_mask_; 95 const Vmm weights_const_vmm_; 96 const size_t number_vmm_single_compute_ = 0; 97 const Xbyak::Opmask &tail_opmask_ = k1; 98 const Xbyak::Reg64 ®_tmp_ = r15; 99 100 io::jit_io_multi_dt_helper_t<Vmm> io_; 101 }; 102 103 } // namespace x64 104 } // namespace cpu 105 } // namespace impl 106 } // namespace dnnl 107 108 #endif 109