1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP
18 #define CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP
19 
#include <cstdint>

#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp"
#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp"
#include "cpu/x64/lrn/lrn_executor.hpp"
23 
24 namespace dnnl {
25 namespace impl {
26 namespace cpu {
27 namespace x64 {
28 namespace lrn {
29 
30 template <::dnnl::impl::data_type_t d_type, typename PD_T>
31 class lrn_avx512_nhwc_executor_fwd_t : public i_lrn_executor_t {
32 public:
lrn_avx512_nhwc_executor_fwd_t(const PD_T * pd)33     lrn_avx512_nhwc_executor_fwd_t(const PD_T *pd)
34         : ker_(utils::make_unique<
35                 lrn::jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>>(pd->C(),
36                 pd->desc()->prop_kind,
37                 pd->desc()->lrn_alpha / pd->desc()->local_size,
38                 pd->desc()->lrn_beta, pd->desc()->lrn_k,
39                 pd->desc()->local_size))
40         , N_(pd->MB())
41         , C_(pd->C())
42         , H_(pd->H())
43         , W_(pd->W()) {}
44 
45     using data_t = typename prec_traits<d_type>::type;
46 
create_kernel()47     status_t create_kernel() override { return ker_->create_kernel(); }
48 
execute(const exec_ctx_t & ctx) const49     status_t execute(const exec_ctx_t &ctx) const override {
50         status_t status = status::success;
51         const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
52         const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
53         CHECK(status);
54         const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
55         CHECK(status);
56 
57         const auto ker = ker_.get();
58         parallel_nd(N_, H_ * W_, [&](int n, int pixel_id) {
59             typename lrn::jit_avx512_common_lrn_kernel_fwd_t<
60                     d_type>::jit_args_fwd_t args;
61             const auto offset = n * C_ * H_ * W_ + pixel_id * C_;
62             const auto ws_offset0 = offset * 2;
63             const auto ws_offset1 = ws_offset0 + C_;
64 
65             args.src = &src[offset];
66             args.dst = &dst[offset];
67             args.ws0 = ws ? &ws[ws_offset0] : nullptr;
68             args.ws1 = ws ? &ws[ws_offset1] : nullptr;
69 
70             (*ker)(&args);
71         });
72 
73         return status::success;
74     }
75 
76     virtual ~lrn_avx512_nhwc_executor_fwd_t() = default;
77 
78 private:
79     std::unique_ptr<jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>> ker_;
80     const int N_;
81     const int C_;
82     const int H_;
83     const int W_;
84 };
85 template <::dnnl::impl::data_type_t d_type, typename PD_T>
86 class lrn_avx512_nhwc_executor_bwd_t : public i_lrn_executor_t {
87 public:
lrn_avx512_nhwc_executor_bwd_t(const PD_T * pd)88     lrn_avx512_nhwc_executor_bwd_t(const PD_T *pd)
89         : ker_ {utils::make_unique<
90                 lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>>(pd->C(),
91                 pd->desc()->lrn_alpha / pd->desc()->local_size,
92                 pd->desc()->lrn_beta, pd->desc()->local_size)}
93         , N_(pd->MB())
94         , C_(pd->C())
95         , H_(pd->H())
96         , W_(pd->W()) {}
97     using data_t = typename prec_traits<d_type>::type;
98 
create_kernel()99     status_t create_kernel() override { return ker_->create_kernel(); }
100 
execute(const exec_ctx_t & ctx) const101     status_t execute(const exec_ctx_t &ctx) const override {
102         status_t status = status::success;
103         auto src = CTX_IN_MEM(data_t *, DNNL_ARG_SRC);
104         auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
105         CHECK(status);
106         auto diff_dst = CTX_IN_MEM(data_t *, DNNL_ARG_DIFF_DST);
107         auto ws = CTX_IN_MEM(data_t *, DNNL_ARG_WORKSPACE);
108 
109         const auto ker = ker_.get();
110         parallel_nd(N_, H_ * W_, [&](int n, int pixel_id) {
111             typename lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<
112                     d_type>::jit_args_bwd_t args;
113             const auto offset = n * C_ * H_ * W_ + pixel_id * C_;
114             const auto ws_offset0 = offset * 2;
115             const auto ws_offset1 = ws_offset0 + C_;
116 
117             args.src = &src[offset];
118             args.diff_dst = &diff_dst[offset];
119             args.ws0 = &ws[ws_offset0];
120             args.ws1 = &ws[ws_offset1];
121             args.diff_src = &diff_src[offset];
122 
123             (*ker)(&args);
124         });
125 
126         return status::success;
127     }
128 
129     virtual ~lrn_avx512_nhwc_executor_bwd_t() = default;
130 
131 private:
132     std::unique_ptr<jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>> ker_;
133     const int N_;
134     const int C_;
135     const int H_;
136     const int W_;
137 };
138 
139 } // namespace lrn
140 } // namespace x64
141 } // namespace cpu
142 } // namespace impl
143 } // namespace dnnl
144 
145 #endif
146