1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_BRGEMM_1X1_CONV_HPP
18 #define CPU_X64_JIT_BRGEMM_1X1_CONV_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_tracking.hpp"
23 #include "common/primitive.hpp"
24 #include "common/utils.hpp"
25 
26 #include "cpu/cpu_convolution_pd.hpp"
27 #include "cpu/platform.hpp"
28 
29 #include "cpu/x64/brgemm/brgemm.hpp"
30 #include "cpu/x64/cpu_barrier.hpp"
31 #include "cpu/x64/cpu_reducer.hpp"
32 #include "cpu/x64/jit_brgemm_conv_utils.hpp"
33 #include "cpu/x64/jit_brgemm_post_ops.hpp"
34 
35 namespace dnnl {
36 namespace impl {
37 namespace cpu {
38 namespace x64 {
39 
40 template <cpu_isa_t isa, impl::data_type_t src_type,
41         impl::data_type_t wei_type = src_type,
42         impl::data_type_t dst_type = src_type>
43 struct brgemm_1x1_convolution_fwd_t : public primitive_t {
44     struct pd_t : public cpu_convolution_fwd_pd_t {
pd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t::pd_t45         pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
46                 const typename pd_t::base_class *hint_fwd_pd)
47             : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
48             , attr_(attr)
49             , with_sum(false)
50             , sum_scale(0) {}
51 
52         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv_1x1:", isa, ""),
53                 brgemm_1x1_convolution_fwd_t);
54 
55         status_t init(engine_t *engine);
56 
57         const primitive_attr_t *attr_;
58         brgemm_t brgs_[16];
59         bool with_sum;
60         float sum_scale;
61 
62         jit_brgemm_conv_conf_t jcp_;
63     };
64 
brgemm_1x1_convolution_fwd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t65     brgemm_1x1_convolution_fwd_t(const pd_t *apd)
66         : primitive_t(apd), bias_d(pd()->weights_md(1)) {}
67 
~brgemm_1x1_convolution_fwd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t68     ~brgemm_1x1_convolution_fwd_t() {}
69 
70     typedef typename prec_traits<src_type>::type src_data_t;
71     typedef typename prec_traits<wei_type>::type wei_data_t;
72     typedef typename prec_traits<dst_type>::type dst_data_t;
73 
executednnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t74     status_t execute(const exec_ctx_t &ctx) const override {
75         execute_forward_all(ctx);
76 
77         if (pd()->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx);
78 
79         return status::success;
80     }
81 
82 protected:
83     status_t init(engine_t *engine) override;
84 
85 private:
86     //  brgemm convolution execution context
87     struct brgemm_exec_ctx_t {
brgemm_exec_ctx_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t::brgemm_exec_ctx_t88         brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd)
89             : src(CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC))
90             , weights(CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS))
91             , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS))
92             , dst(CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST))
93             , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args(
94                       pd->attr()->post_ops_, ctx)) {}
95         const src_data_t *const __restrict src;
96         const wei_data_t *const __restrict weights;
97         const char *const __restrict bias;
98         dst_data_t *const __restrict dst;
99         const std::vector<const void *> post_ops_binary_rhs_arg_vec;
100     };
101 
102     void exec_ker(const brgemm_exec_ctx_t &brgemm_ctx, int ithr,
103             brgemm_batch_element_t *const __restrict brg_batch,
104             char *const c_buffer, int g, int n, int ocb, int od, int oh, int ow,
105             int icc) const;
106     void execute_forward_all(const exec_ctx_t &ctx) const;
pddnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t107     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
108 
get_brg_idxdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t109     static int get_brg_idx(bool do_initialization, int is_M_tail,
110             bool is_N_tail, bool is_K_tail) {
111         return (((int)do_initialization * 2 + (int)is_M_tail) * 2
112                        + (int)is_N_tail)
113                 * 2
114                 + (int)is_K_tail;
115     }
116 
get_ker_po_idxdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t117     static int get_ker_po_idx(int is_M_tail, bool is_N_tail) {
118         return (int)is_M_tail * 2 + (int)is_N_tail;
119     }
120 
121     std::unique_ptr<brgemm_kernel_t> brg_kernels_[16];
122     std::unique_ptr<jit_brgemm_kernel_post_ops> kernels_po_[4];
123 
124     const memory_desc_wrapper bias_d;
125 
126     int ID, IH, IW, OD, OH, OW, SD, SH, SW;
127     size_t bia_dsz, acc_dsz, src_dsz, wei_dsz;
128     bool need_postwork;
129     int ic_chunks;
130     // const variables used for address calculations
131     dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz,
132             wei_ic_sz, wei_ocb_sz;
133 };
134 
135 } // namespace x64
136 } // namespace cpu
137 } // namespace impl
138 } // namespace dnnl
139 
140 #endif
141 
142 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
143