1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_
18 #define CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_
19 
20 #include "cpu/x64/jit_generator.hpp"
21 #include "cpu/x64/prelu/jit_prelu_utils.hpp"
22 
23 namespace dnnl {
24 namespace impl {
25 namespace cpu {
26 namespace x64 {
27 
// Abstract base for JIT PReLU kernels built on top of jit_generator.
// It declares the state shared by concrete kernels (ISA, SIMD width,
// broadcast kind, tail size, two general-purpose registers) together with
// the pure-virtual hooks every concrete kernel has to implement.
// NOTE(review): all member functions are defined outside this header; the
// comments below state only what the declarations show and explicitly mark
// anything inferred from names.
28 class jit_prelu_base_kernel_t : public jit_generator {
29 public:
    // Parameters: target ISA, vector length (`vlen`), broadcast kind, the
    // memory descriptor of a tensor and the number of vector registers
    // needed by a single compute step.
    // NOTE(review): presumably these seed isa_, simd_w_, bcast_, tail_size_,
    // tensor_md_ and number_vmm_single_compute_ - confirm in the definition.
30     jit_prelu_base_kernel_t(const cpu_isa_t &isa, const int vlen,
31             const prelu::bcast &bcast, const memory_desc_wrapper &tensor_md,
32             const size_t number_vmm_single_compute);
33 
    // Non-throwing const accessors.
    // NOTE(review): presumably return simd_w_ and bcast_ - confirm.
34     size_t simd_w() const noexcept;
35     prelu::bcast get_bcast() const noexcept;
36 
37 protected:
    // NOTE(review): presumably hands out the index of a newly reserved vector
    // register and updates number_reserved_vmms_ (non-const) - confirm.
38     int reserve_vmm();
    // NOTE(review): presumably maps (base_idx, unroll_group) to the index of a
    // vector register used for computation - confirm in the definition.
39     int get_compute_vmm(size_t base_idx, size_t unroll_group) const;
40 
    // NOTE(review): presumably returns number_reserved_vmms_ - confirm.
41     size_t get_number_reserved_vmms() const noexcept;
42 
    // Immutable configuration visible to derived kernels. isa_ has no default
    // member initializer, so the constructor must initialise it explicitly;
    // the remaining members fall back to the defaults written here.
    // Non-static data members are initialised in declaration order, so
    // tail_size_ is initialised before the private tensor_md_ declared further
    // down; its initializer therefore must not read tensor_md_.
43     const cpu_isa_t isa_;
44     const size_t simd_w_ = 0;
45     const prelu::bcast bcast_ = prelu::bcast::unsupported;
46     const size_t tail_size_ = 0u;
    // Const references bound to `r8` and `r9`. Those names are not declared in
    // this header; presumably they are the Xbyak registers made available
    // through jit_generator - confirm.
47     const Xbyak::Reg64 &reg_data_size_ = r8;
48     const Xbyak::Reg64 &reg_offset_ = r9;
49 
50 private:
    // Overrides a virtual function inherited from a base class.
    // NOTE(review): presumably the code-emission entry point that drives the
    // hooks declared below - confirm against the definition.
51     void generate() override;
    // Pure-virtual customisation points. Being private, they can be called
    // only from this class's own members, yet derived classes must still
    // override them to become instantiable.
52     virtual bool any_tensor_bf16() const = 0;
53     virtual void load_kernel_call_params() = 0;
54     virtual void prepare_kernel_const_vars() = 0;
    // `unrolling_factor` and `tail` are supplied by the caller.
    // NOTE(review): presumably the number of vectors processed per step and
    // whether the step handles the tail elements - confirm.
55     virtual void compute_dst(size_t unrolling_factor, bool tail) = 0;
56     virtual void finalize() = 0;
    // Non-throwing const helpers; the computations themselves are defined
    // elsewhere.
57     size_t calc_unrolling_factor() const noexcept;
58     size_t calc_tail_size(const memory_desc_wrapper &tensor_md) const noexcept;
    // Descriptor wrapper stored by value (not by reference).
59     const memory_desc_wrapper tensor_md_;
60     const size_t number_vmm_single_compute_ = 0;
    // The only mutable data member of the class; starts at zero.
61     size_t number_reserved_vmms_ = 0;
62 };
63 
64 } // namespace x64
65 } // namespace cpu
66 } // namespace impl
67 } // namespace dnnl
68 
69 #endif
70