1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_ 18 #define CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_ 19 20 #include "cpu/x64/jit_generator.hpp" 21 #include "cpu/x64/prelu/jit_prelu_utils.hpp" 22 23 namespace dnnl { 24 namespace impl { 25 namespace cpu { 26 namespace x64 { 27 28 class jit_prelu_base_kernel_t : public jit_generator { 29 public: 30 jit_prelu_base_kernel_t(const cpu_isa_t &isa, const int vlen, 31 const prelu::bcast &bcast, const memory_desc_wrapper &tensor_md, 32 const size_t number_vmm_single_compute); 33 34 size_t simd_w() const noexcept; 35 prelu::bcast get_bcast() const noexcept; 36 37 protected: 38 int reserve_vmm(); 39 int get_compute_vmm(size_t base_idx, size_t unroll_group) const; 40 41 size_t get_number_reserved_vmms() const noexcept; 42 43 const cpu_isa_t isa_; 44 const size_t simd_w_ = 0; 45 const prelu::bcast bcast_ = prelu::bcast::unsupported; 46 const size_t tail_size_ = 0u; 47 const Xbyak::Reg64 ®_data_size_ = r8; 48 const Xbyak::Reg64 ®_offset_ = r9; 49 50 private: 51 void generate() override; 52 virtual bool any_tensor_bf16() const = 0; 53 virtual void load_kernel_call_params() = 0; 54 virtual void prepare_kernel_const_vars() = 0; 55 virtual void compute_dst(size_t unrolling_factor, bool tail) = 0; 56 virtual void finalize() = 0; 57 size_t calc_unrolling_factor() const noexcept; 58 size_t calc_tail_size(const memory_desc_wrapper &tensor_md) const noexcept; 59 const memory_desc_wrapper tensor_md_; 60 const size_t number_vmm_single_compute_ = 0; 61 size_t number_reserved_vmms_ = 0; 62 }; 63 64 } // namespace x64 65 } // namespace cpu 66 } // namespace impl 67 } // namespace dnnl 68 69 #endif 70