1 /******************************************************************************* 2 * Copyright 2020-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP 18 #define CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP 19 20 #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp" 21 #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp" 22 #include "cpu/x64/lrn/lrn_executor.hpp" 23 24 namespace dnnl { 25 namespace impl { 26 namespace cpu { 27 namespace x64 { 28 namespace lrn { 29 30 template <::dnnl::impl::data_type_t d_type, typename PD_T> 31 class lrn_avx512_nhwc_executor_fwd_t : public i_lrn_executor_t { 32 public: lrn_avx512_nhwc_executor_fwd_t(const PD_T * pd)33 lrn_avx512_nhwc_executor_fwd_t(const PD_T *pd) 34 : ker_(utils::make_unique< 35 lrn::jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>>(pd->C(), 36 pd->desc()->prop_kind, 37 pd->desc()->lrn_alpha / pd->desc()->local_size, 38 pd->desc()->lrn_beta, pd->desc()->lrn_k, 39 pd->desc()->local_size)) 40 , N_(pd->MB()) 41 , C_(pd->C()) 42 , H_(pd->H()) 43 , W_(pd->W()) {} 44 45 using data_t = typename prec_traits<d_type>::type; 46 create_kernel()47 status_t create_kernel() override { return ker_->create_kernel(); } 48 execute(const exec_ctx_t & ctx) const49 status_t execute(const exec_ctx_t &ctx) const override { 50 status_t status = status::success; 51 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); 52 const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); 53 CHECK(status); 54 const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status); 55 CHECK(status); 56 57 const auto ker = ker_.get(); 58 parallel_nd(N_, H_ * W_, [&](int n, int pixel_id) { 59 typename lrn::jit_avx512_common_lrn_kernel_fwd_t< 60 d_type>::jit_args_fwd_t args; 61 const auto offset = n * C_ * H_ * W_ + pixel_id * C_; 62 const auto ws_offset0 = offset * 2; 63 const auto ws_offset1 = ws_offset0 + C_; 64 65 args.src = &src[offset]; 66 args.dst = &dst[offset]; 67 args.ws0 = ws ? &ws[ws_offset0] : nullptr; 68 args.ws1 = ws ? &ws[ws_offset1] : nullptr; 69 70 (*ker)(&args); 71 }); 72 73 return status::success; 74 } 75 76 virtual ~lrn_avx512_nhwc_executor_fwd_t() = default; 77 78 private: 79 std::unique_ptr<jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>> ker_; 80 const int N_; 81 const int C_; 82 const int H_; 83 const int W_; 84 }; 85 template <::dnnl::impl::data_type_t d_type, typename PD_T> 86 class lrn_avx512_nhwc_executor_bwd_t : public i_lrn_executor_t { 87 public: lrn_avx512_nhwc_executor_bwd_t(const PD_T * pd)88 lrn_avx512_nhwc_executor_bwd_t(const PD_T *pd) 89 : ker_ {utils::make_unique< 90 lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>>(pd->C(), 91 pd->desc()->lrn_alpha / pd->desc()->local_size, 92 pd->desc()->lrn_beta, pd->desc()->local_size)} 93 , N_(pd->MB()) 94 , C_(pd->C()) 95 , H_(pd->H()) 96 , W_(pd->W()) {} 97 using data_t = typename prec_traits<d_type>::type; 98 create_kernel()99 status_t create_kernel() override { return ker_->create_kernel(); } 100 execute(const exec_ctx_t & ctx) const101 status_t execute(const exec_ctx_t &ctx) const override { 102 status_t status = status::success; 103 auto src = CTX_IN_MEM(data_t *, DNNL_ARG_SRC); 104 auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status); 105 CHECK(status); 106 auto diff_dst = CTX_IN_MEM(data_t *, DNNL_ARG_DIFF_DST); 107 auto ws = CTX_IN_MEM(data_t *, DNNL_ARG_WORKSPACE); 108 109 const auto ker = ker_.get(); 110 parallel_nd(N_, H_ * W_, [&](int n, int pixel_id) { 111 typename lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t< 112 d_type>::jit_args_bwd_t args; 113 const auto offset = n * C_ * H_ * W_ + pixel_id * C_; 114 const auto ws_offset0 = offset * 2; 115 const auto ws_offset1 = ws_offset0 + C_; 116 117 args.src = &src[offset]; 118 args.diff_dst = &diff_dst[offset]; 119 args.ws0 = &ws[ws_offset0]; 120 args.ws1 = &ws[ws_offset1]; 121 args.diff_src = &diff_src[offset]; 122 123 (*ker)(&args); 124 }); 125 126 return status::success; 127 } 128 129 virtual ~lrn_avx512_nhwc_executor_bwd_t() = default; 130 131 private: 132 std::unique_ptr<jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>> ker_; 133 const int N_; 134 const int C_; 135 const int H_; 136 const int W_; 137 }; 138 139 } // namespace lrn 140 } // namespace x64 141 } // namespace cpu 142 } // namespace impl 143 } // namespace dnnl 144 145 #endif 146