/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
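
// Reference (non-optimized) OpenCL implementation of local response
// normalization (LRN), forward and backward.
//
// A minimal usage sketch, assuming the oneDNN v2.x C++ API from dnnl.hpp;
// eng, strm, src_md, the memory objects, and the LRN parameters below are
// placeholders the caller provides, and whether this "ref:any" implementation
// is the one actually dispatched depends on the engine and on the other
// implementations registered ahead of it:
//
//   auto d = dnnl::lrn_forward::desc(dnnl::prop_kind::forward_training,
//           dnnl::algorithm::lrn_across_channels, src_md, local_size,
//           alpha, beta, k);
//   auto pd = dnnl::lrn_forward::primitive_desc(d, eng);
//   dnnl::lrn_forward(pd).execute(strm,
//           {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
//                   {DNNL_ARG_WORKSPACE, ws_mem}});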

#ifndef GPU_OCL_REF_LRN_HPP
#define GPU_OCL_REF_LRN_HPP

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_lrn_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

struct ref_lrn_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_lrn_fwd_pd_t {
        pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
                const lrn_fwd_pd_t *hint_fwd_pd)
            : gpu_lrn_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
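            // Supported: forward inference/training, across- and
            // within-channel LRN, f32/f16/bf16 data (f16 requires the
            // cl_khr_fp16 device extension).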
            bool ok = true
                    && utils::one_of(desc()->prop_kind,
                            prop_kind::forward_inference,
                            prop_kind::forward_training)
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::lrn_across_channels,
                            alg_kind::lrn_within_channel)
                    && utils::one_of(
                            desc()->data_desc.data_type, f32, f16, bf16)
                    && attr()->has_default_values()
                    && IMPLICATION(desc()->data_desc.data_type == f16,
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16));
            if (!ok) return status::unimplemented;

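            // Forward training keeps a workspace with the same shape as src
            // that caches intermediate normalization values for the backward
            // pass; for bf16 data the workspace is stored in f32.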
            if (desc_.prop_kind == prop_kind::forward_training) {
                ws_md_ = *src_md();
                if (ws_md_.data_type == data_type::bf16)
                    ws_md_.data_type = data_type::f32;
            }

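            // Dispatch one work-item per output element; for tensors with
            // fewer than five dims, the missing spatial dims have size 1 and
            // nstl::max keeps the mapped dim index in range.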
            dispatch = compute_engine->create_dispatch(&data_md_);
            dispatch.define_dim("MB", 0, MB());
            dispatch.define_dim("IC", 1, C());
            dispatch.define_dim("ID", nstl::max(1, data_md_.ndims - 3), D());
            dispatch.define_dim("IH", nstl::max(1, data_md_.ndims - 2), H());
            dispatch.define_dim("IW", nstl::max(1, data_md_.ndims - 1), W());
            dispatch.generate();

            return status::success;
        }

        compute::dispatch_t dispatch;
    };

    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = status::success;
        const auto *desc = pd()->desc();

        kernel_ctx.set_data_type(desc->data_desc.data_type);

        kernel_ctx.define_int("IS_FWD", 1);

        if (desc->prop_kind == prop_kind::forward_training)
            kernel_ctx.define_int("IS_TRAINING", 1);

        switch (desc->alg_kind) {
            case lrn_across_channels:
                kernel_ctx.define_int("ACROSS_CHANNEL", 1);
                break;
            case lrn_within_channel:
                kernel_ctx.define_int("WITHIN_CHANNEL", 1);
                break;
            default: status = status::unimplemented;
        }
        if (status != status::success) return status;

        const memory_desc_wrapper src_d(pd()->src_md());
        const memory_desc_wrapper dst_d(pd()->dst_md());
        const int ndims = src_d.ndims();

        kernel_ctx.define_int("NDIMS", ndims);
        kernel_ctx.define_int("MB", pd()->MB());
        kernel_ctx.define_int("IC", pd()->C());
        kernel_ctx.define_int("ID", pd()->D());
        kernel_ctx.define_int("IH", pd()->H());
        kernel_ctx.define_int("IW", pd()->W());

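        // Number of elements in the normalization window: local_size channels
        // for across-channel LRN, or local_size^(ndims - 2) spatial points for
        // within-channel LRN. The kernel gets the reciprocal
        // (NUM_ELEMENTS_DIV) for averaging, and PADDING is the half-window on
        // each side.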
        const uint32_t round_norm_size = desc->local_size;
        uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2));
        if (desc->alg_kind == lrn_across_channels) {
            num_elements = round_norm_size;
        }
        const float num_element_div = 1.f / (float)num_elements;
        const auto padding = (desc->local_size - 1) / 2;

        kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div);
        kernel_ctx.define_int("PADDING", padding);
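        // Round an even local_size down to the nearest odd value so the
        // window is symmetric around the center element.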
        kernel_ctx.define_int(
                "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2);
        kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha);
        kernel_ctx.define_float("LRN_BETA", desc->lrn_beta);
        kernel_ctx.define_float("LRN_K", desc->lrn_k);

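        // Export src/dst stride information so the kernel can compute element
        // offsets for the given memory layouts.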
        offsets_t off;
        set_offsets(src_d, off.src_off);
        set_offsets(dst_d, off.dst_off);
        def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
        def_offsets(off.dst_off, kernel_ctx, "DST", ndims);

        def_dispatch(kernel_ctx, pd()->dispatch);

        create_kernel(engine, &kernel_, "ref_lrn_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
};

struct ref_lrn_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_lrn_bwd_pd_t {
        pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr,
                const lrn_fwd_pd_t *hint_fwd_pd)
            : gpu_lrn_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
        virtual ~pd_t() {}

        DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t);

        status_t init(engine_t *engine) {
            assert(engine->kind() == engine_kind::gpu);
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);
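            // Supported: backward_data only, across- and within-channel LRN,
            // f32/bf16 data.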
            bool ok = true
                    && utils::one_of(
                            desc()->prop_kind, prop_kind::backward_data)
                    && utils::one_of(desc()->alg_kind,
                            alg_kind::lrn_across_channels,
                            alg_kind::lrn_within_channel)
                    && utils::one_of(desc()->data_desc.data_type,
                            data_type::f32, data_type::bf16)
                    && set_default_formats_common()
                    && attr()->has_default_values()
                    && IMPLICATION(
                            desc()->data_desc.data_type == data_type::f16,
                            compute_engine->mayiuse(
                                    compute::device_ext_t::khr_fp16));
            if (!ok) return status::unimplemented;

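            // Backward needs the workspace produced by forward training;
            // compare_ws() rejects a hint whose workspace layout does not
            // match.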
            ws_md_ = *src_md();
            if (ws_md_.data_type == data_type::bf16)
                ws_md_.data_type = data_type::f32;
            if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;

            dispatch = compute_engine->create_dispatch(&diff_data_md_);
            dispatch.define_dim("MB", 0, MB());
            dispatch.define_dim("IC", 1, C());
            dispatch.define_dim("ID", nstl::max(1, data_md_.ndims - 3), D());
            dispatch.define_dim("IH", nstl::max(1, data_md_.ndims - 2), H());
            dispatch.define_dim("IW", nstl::max(1, data_md_.ndims - 1), W());
            dispatch.generate();

            return status::success;
        }

        compute::dispatch_t dispatch;
    };

    status_t init(engine_t *engine) override {
        using namespace alg_kind;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = status::success;
        const auto *desc = pd()->desc();

        kernel_ctx.set_data_type(desc->data_desc.data_type);

        kernel_ctx.define_int("IS_BWD", 1);

        switch (desc->alg_kind) {
            case lrn_across_channels:
                kernel_ctx.define_int("ACROSS_CHANNEL", 1);
                break;
            case lrn_within_channel:
                kernel_ctx.define_int("WITHIN_CHANNEL", 1);
                break;
            default: status = status::unimplemented;
        }
        if (status != status::success) return status;

        const memory_desc_wrapper src_d(pd()->src_md());
        const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
        const int ndims = src_d.ndims();

        kernel_ctx.define_int("NDIMS", ndims);
        kernel_ctx.define_int("MB", pd()->MB());
        kernel_ctx.define_int("IC", pd()->C());
        kernel_ctx.define_int("ID", pd()->D());
        kernel_ctx.define_int("IH", pd()->H());
        kernel_ctx.define_int("IW", pd()->W());

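        // Same normalization-window math as in the forward implementation
        // above.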
        const uint32_t round_norm_size = desc->local_size;
        uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2));
        if (desc->alg_kind == lrn_across_channels) {
            num_elements = round_norm_size;
        }
        const float num_element_div = 1.f / (float)num_elements;
        const auto padding = (desc->local_size - 1) / 2;

        kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div);
        kernel_ctx.define_int("PADDING", padding);
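        // As in forward: force an odd LOCAL_SIZE.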
        kernel_ctx.define_int(
                "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2);
        kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha);
        kernel_ctx.define_float("LRN_BETA", desc->lrn_beta);
        kernel_ctx.define_float("LRN_K", desc->lrn_k);

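        // The diff_dst offsets are exported under the "DST" name that the
        // common kernel source expects.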
        offsets_t off;
        set_offsets(src_d, off.src_off);
        set_offsets(diff_dst_d, off.dst_off);
        def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
        def_offsets(off.dst_off, kernel_ctx, "DST", ndims);

        def_dispatch(kernel_ctx, pd()->dispatch);

        create_kernel(engine, &kernel_, "ref_lrn_bwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif
294