1 /******************************************************************************* 2 * Copyright 2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_BRGEMM_1X1_CONV_HPP 18 #define CPU_X64_JIT_BRGEMM_1X1_CONV_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/dnnl_thread.hpp" 22 #include "common/memory_tracking.hpp" 23 #include "common/primitive.hpp" 24 #include "common/utils.hpp" 25 26 #include "cpu/cpu_convolution_pd.hpp" 27 #include "cpu/platform.hpp" 28 29 #include "cpu/x64/brgemm/brgemm.hpp" 30 #include "cpu/x64/cpu_barrier.hpp" 31 #include "cpu/x64/cpu_reducer.hpp" 32 #include "cpu/x64/jit_brgemm_conv_utils.hpp" 33 #include "cpu/x64/jit_brgemm_post_ops.hpp" 34 35 namespace dnnl { 36 namespace impl { 37 namespace cpu { 38 namespace x64 { 39 40 template <cpu_isa_t isa, impl::data_type_t src_type, 41 impl::data_type_t wei_type = src_type, 42 impl::data_type_t dst_type = src_type> 43 struct brgemm_1x1_convolution_fwd_t : public primitive_t { 44 struct pd_t : public cpu_convolution_fwd_pd_t { pd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t::pd_t45 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 46 const typename pd_t::base_class *hint_fwd_pd) 47 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) 48 , attr_(attr) 49 , with_sum(false) 50 , sum_scale(0) {} 51 52 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("brgconv_1x1:", isa, ""), 53 brgemm_1x1_convolution_fwd_t); 54 55 status_t init(engine_t *engine); 56 57 const primitive_attr_t *attr_; 58 brgemm_t brgs_[16]; 59 bool with_sum; 60 float sum_scale; 61 62 jit_brgemm_conv_conf_t jcp_; 63 }; 64 brgemm_1x1_convolution_fwd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t65 brgemm_1x1_convolution_fwd_t(const pd_t *apd) 66 : primitive_t(apd), bias_d(pd()->weights_md(1)) {} 67 ~brgemm_1x1_convolution_fwd_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t68 ~brgemm_1x1_convolution_fwd_t() {} 69 70 typedef typename prec_traits<src_type>::type src_data_t; 71 typedef typename prec_traits<wei_type>::type wei_data_t; 72 typedef typename prec_traits<dst_type>::type dst_data_t; 73 executednnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t74 status_t execute(const exec_ctx_t &ctx) const override { 75 execute_forward_all(ctx); 76 77 if (pd()->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx); 78 79 return status::success; 80 } 81 82 protected: 83 status_t init(engine_t *engine) override; 84 85 private: 86 // brgemm convolution execution context 87 struct brgemm_exec_ctx_t { brgemm_exec_ctx_tdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t::brgemm_exec_ctx_t88 brgemm_exec_ctx_t(const exec_ctx_t &ctx, const pd_t *pd) 89 : src(CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC)) 90 , weights(CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS)) 91 , bias(CTX_IN_MEM(const char *, DNNL_ARG_BIAS)) 92 , dst(CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST)) 93 , post_ops_binary_rhs_arg_vec(binary_injector::prepare_binary_args( 94 pd->attr()->post_ops_, ctx)) {} 95 const src_data_t *const __restrict src; 96 const wei_data_t *const __restrict weights; 97 const char *const __restrict bias; 98 dst_data_t *const __restrict dst; 99 const std::vector<const void *> post_ops_binary_rhs_arg_vec; 100 }; 101 102 void exec_ker(const brgemm_exec_ctx_t &brgemm_ctx, int ithr, 103 brgemm_batch_element_t *const __restrict brg_batch, 104 char *const c_buffer, int g, int n, int ocb, int od, int oh, int ow, 105 int icc) const; 106 void execute_forward_all(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t107 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 108 get_brg_idxdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t109 static int get_brg_idx(bool do_initialization, int is_M_tail, 110 bool is_N_tail, bool is_K_tail) { 111 return (((int)do_initialization * 2 + (int)is_M_tail) * 2 112 + (int)is_N_tail) 113 * 2 114 + (int)is_K_tail; 115 } 116 get_ker_po_idxdnnl::impl::cpu::x64::brgemm_1x1_convolution_fwd_t117 static int get_ker_po_idx(int is_M_tail, bool is_N_tail) { 118 return (int)is_M_tail * 2 + (int)is_N_tail; 119 } 120 121 std::unique_ptr<brgemm_kernel_t> brg_kernels_[16]; 122 std::unique_ptr<jit_brgemm_kernel_post_ops> kernels_po_[4]; 123 124 const memory_desc_wrapper bias_d; 125 126 int ID, IH, IW, OD, OH, OW, SD, SH, SW; 127 size_t bia_dsz, acc_dsz, src_dsz, wei_dsz; 128 bool need_postwork; 129 int ic_chunks; 130 // const variables used for address calculations 131 dim_t src_w_sz, src_h_sz, src_d_sz, dst_w_sz, dst_h_sz, dst_d_sz, wei_oc_sz, 132 wei_ic_sz, wei_ocb_sz; 133 }; 134 135 } // namespace x64 136 } // namespace cpu 137 } // namespace impl 138 } // namespace dnnl 139 140 #endif 141 142 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 143