1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef GPU_OCL_REF_LRN_HPP 18 #define GPU_OCL_REF_LRN_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/nstl.hpp" 22 #include "common/primitive.hpp" 23 #include "common/type_helpers.hpp" 24 #include "gpu/compute/compute.hpp" 25 #include "gpu/gpu_lrn_pd.hpp" 26 #include "gpu/gpu_primitive.hpp" 27 #include "gpu/gpu_resource.hpp" 28 #include "gpu/ocl/ocl_stream.hpp" 29 #include "gpu/ocl/ocl_utils.hpp" 30 #include "gpu/primitive_conf.hpp" 31 32 namespace dnnl { 33 namespace impl { 34 namespace gpu { 35 namespace ocl { 36 37 struct ref_lrn_fwd_t : public gpu_primitive_t { 38 using gpu_primitive_t::gpu_primitive_t; 39 struct pd_t : public gpu_lrn_fwd_pd_t { pd_tdnnl::impl::gpu::ocl::ref_lrn_fwd_t::pd_t40 pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, 41 const lrn_fwd_pd_t *hint_fwd_pd) 42 : gpu_lrn_fwd_pd_t(adesc, attr, hint_fwd_pd) {} ~pd_tdnnl::impl::gpu::ocl::ref_lrn_fwd_t::pd_t43 virtual ~pd_t() {} 44 45 DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t); 46 initdnnl::impl::gpu::ocl::ref_lrn_fwd_t::pd_t47 status_t init(engine_t *engine) { 48 using namespace data_type; 49 assert(engine->kind() == engine_kind::gpu); 50 auto *compute_engine 51 = utils::downcast<compute::compute_engine_t *>(engine); 52 bool ok = true 53 && utils::one_of(desc()->prop_kind, 54 prop_kind::forward_inference, 55 prop_kind::forward_training) 56 && utils::one_of(desc()->alg_kind, 57 alg_kind::lrn_across_channels, 58 alg_kind::lrn_within_channel) 59 && utils::one_of( 60 desc()->data_desc.data_type, f32, f16, bf16) 61 && attr()->has_default_values() 62 && IMPLICATION(desc()->data_desc.data_type == f16, 63 compute_engine->mayiuse( 64 compute::device_ext_t::khr_fp16)); 65 if (!ok) return status::unimplemented; 66 67 if (desc_.prop_kind == prop_kind::forward_training) { 68 ws_md_ = *src_md(); 69 if (ws_md_.data_type == data_type::bf16) 70 ws_md_.data_type = data_type::f32; 71 } 72 73 dispatch = compute_engine->create_dispatch(&data_md_); 74 dispatch.define_dim("MB", 0, MB()); 75 dispatch.define_dim("IC", 1, C()); 76 dispatch.define_dim("ID", nstl::max(1, data_md_.ndims - 3), D()); 77 dispatch.define_dim("IH", nstl::max(1, data_md_.ndims - 2), H()); 78 dispatch.define_dim("IW", nstl::max(1, data_md_.ndims - 1), W()); 79 dispatch.generate(); 80 81 return status::success; 82 } 83 84 compute::dispatch_t dispatch; 85 }; 86 initdnnl::impl::gpu::ocl::ref_lrn_fwd_t87 status_t init(engine_t *engine) override { 88 using namespace alg_kind; 89 90 compute::kernel_ctx_t kernel_ctx; 91 92 status_t status = status::success; 93 const auto *desc = pd()->desc(); 94 95 kernel_ctx.set_data_type(desc->data_desc.data_type); 96 97 kernel_ctx.define_int("IS_FWD", 1); 98 99 if (desc->prop_kind == prop_kind::forward_training) 100 kernel_ctx.define_int("IS_TRAINING", 1); 101 102 switch (desc->alg_kind) { 103 case lrn_across_channels: 104 kernel_ctx.define_int("ACROSS_CHANNEL", 1); 105 break; 106 case lrn_within_channel: 107 kernel_ctx.define_int("WITHIN_CHANNEL", 1); 108 break; 109 default: status = status::unimplemented; 110 } 111 if (status != status::success) return status; 112 113 const memory_desc_wrapper src_d(pd()->src_md()); 114 const memory_desc_wrapper dst_d(pd()->dst_md()); 115 const int ndims = src_d.ndims(); 116 117 kernel_ctx.define_int("NDIMS", ndims); 118 kernel_ctx.define_int("MB", pd()->MB()); 119 kernel_ctx.define_int("IC", pd()->C()); 120 kernel_ctx.define_int("ID", pd()->D()); 121 kernel_ctx.define_int("IH", pd()->H()); 122 kernel_ctx.define_int("IW", pd()->W()); 123 124 const uint32_t round_norm_size = desc->local_size; 125 uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2)); 126 if (desc->alg_kind == lrn_across_channels) { 127 num_elements = round_norm_size; 128 } 129 const float num_element_div = 1.f / (float)num_elements; 130 const auto padding = (desc->local_size - 1) / 2; 131 132 kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div); 133 kernel_ctx.define_int("PADDING", padding); 134 kernel_ctx.define_int( 135 "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2); 136 kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha); 137 kernel_ctx.define_float("LRN_BETA", desc->lrn_beta); 138 kernel_ctx.define_float("LRN_K", desc->lrn_k); 139 140 offsets_t off; 141 set_offsets(src_d, off.src_off); 142 set_offsets(dst_d, off.dst_off); 143 def_offsets(off.src_off, kernel_ctx, "SRC", ndims); 144 def_offsets(off.dst_off, kernel_ctx, "DST", ndims); 145 146 def_dispatch(kernel_ctx, pd()->dispatch); 147 148 create_kernel(engine, &kernel_, "ref_lrn_fwd", kernel_ctx); 149 if (!kernel_) return status::runtime_error; 150 151 return status::success; 152 } 153 executednnl::impl::gpu::ocl::ref_lrn_fwd_t154 status_t execute(const exec_ctx_t &ctx) const override { 155 return execute_forward(ctx); 156 } 157 158 private: 159 status_t execute_forward(const exec_ctx_t &ctx) const; pddnnl::impl::gpu::ocl::ref_lrn_fwd_t160 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 161 compute::kernel_t kernel_; 162 }; 163 164 struct ref_lrn_bwd_t : public gpu_primitive_t { 165 using gpu_primitive_t::gpu_primitive_t; 166 struct pd_t : public gpu_lrn_bwd_pd_t { pd_tdnnl::impl::gpu::ocl::ref_lrn_bwd_t::pd_t167 pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, 168 const lrn_fwd_pd_t *hint_fwd_pd) 169 : gpu_lrn_bwd_pd_t(adesc, attr, hint_fwd_pd) {} ~pd_tdnnl::impl::gpu::ocl::ref_lrn_bwd_t::pd_t170 virtual ~pd_t() {} 171 172 DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t); 173 initdnnl::impl::gpu::ocl::ref_lrn_bwd_t::pd_t174 status_t init(engine_t *engine) { 175 assert(engine->kind() == engine_kind::gpu); 176 auto *compute_engine 177 = utils::downcast<compute::compute_engine_t *>(engine); 178 bool ok = true 179 && utils::one_of( 180 desc()->prop_kind, prop_kind::backward_data) 181 && utils::one_of(desc()->alg_kind, 182 alg_kind::lrn_across_channels, 183 alg_kind::lrn_within_channel) 184 && utils::one_of(desc()->data_desc.data_type, 185 data_type::f32, data_type::bf16) 186 && set_default_formats_common() 187 && attr()->has_default_values() 188 && IMPLICATION( 189 desc()->data_desc.data_type == data_type::f16, 190 compute_engine->mayiuse( 191 compute::device_ext_t::khr_fp16)); 192 if (!ok) return status::unimplemented; 193 194 ws_md_ = *src_md(); 195 if (ws_md_.data_type == data_type::bf16) 196 ws_md_.data_type = data_type::f32; 197 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; 198 199 dispatch = compute_engine->create_dispatch(&diff_data_md_); 200 dispatch.define_dim("MB", 0, MB()); 201 dispatch.define_dim("IC", 1, C()); 202 dispatch.define_dim("ID", nstl::max(1, data_md_.ndims - 3), D()); 203 dispatch.define_dim("IH", nstl::max(1, data_md_.ndims - 2), H()); 204 dispatch.define_dim("IW", nstl::max(1, data_md_.ndims - 1), W()); 205 dispatch.generate(); 206 207 return status::success; 208 } 209 210 compute::dispatch_t dispatch; 211 }; 212 initdnnl::impl::gpu::ocl::ref_lrn_bwd_t213 status_t init(engine_t *engine) override { 214 using namespace alg_kind; 215 216 compute::kernel_ctx_t kernel_ctx; 217 218 status_t status = status::success; 219 const auto *desc = pd()->desc(); 220 221 kernel_ctx.set_data_type(desc->data_desc.data_type); 222 223 kernel_ctx.define_int("IS_BWD", 1); 224 225 switch (desc->alg_kind) { 226 case lrn_across_channels: 227 kernel_ctx.define_int("ACROSS_CHANNEL", 1); 228 break; 229 case lrn_within_channel: 230 kernel_ctx.define_int("WITHIN_CHANNEL", 1); 231 break; 232 default: status = status::unimplemented; 233 } 234 if (status != status::success) return status; 235 236 const memory_desc_wrapper src_d(pd()->src_md()); 237 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); 238 const int ndims = src_d.ndims(); 239 240 kernel_ctx.define_int("NDIMS", ndims); 241 kernel_ctx.define_int("MB", pd()->MB()); 242 kernel_ctx.define_int("IC", pd()->C()); 243 kernel_ctx.define_int("ID", pd()->D()); 244 kernel_ctx.define_int("IH", pd()->H()); 245 kernel_ctx.define_int("IW", pd()->W()); 246 247 const uint32_t round_norm_size = desc->local_size; 248 uint32_t num_elements = pow(round_norm_size, nstl::max(0, ndims - 2)); 249 if (desc->alg_kind == lrn_across_channels) { 250 num_elements = round_norm_size; 251 } 252 const float num_element_div = 1.f / (float)num_elements; 253 const auto padding = (desc->local_size - 1) / 2; 254 255 kernel_ctx.define_float("NUM_ELEMENTS_DIV", num_element_div); 256 kernel_ctx.define_int("PADDING", padding); 257 kernel_ctx.define_int( 258 "LOCAL_SIZE", desc->local_size - 1 + desc->local_size % 2); 259 kernel_ctx.define_float("LRN_ALPHA", desc->lrn_alpha); 260 kernel_ctx.define_float("LRN_BETA", desc->lrn_beta); 261 kernel_ctx.define_float("LRN_K", desc->lrn_k); 262 263 offsets_t off; 264 set_offsets(src_d, off.src_off); 265 set_offsets(diff_dst_d, off.dst_off); 266 def_offsets(off.src_off, kernel_ctx, "SRC", ndims); 267 def_offsets(off.dst_off, kernel_ctx, "DST", ndims); 268 269 def_dispatch(kernel_ctx, pd()->dispatch); 270 271 create_kernel(engine, &kernel_, "ref_lrn_bwd", kernel_ctx); 272 if (!kernel_) return status::runtime_error; 273 274 return status::success; 275 } 276 executednnl::impl::gpu::ocl::ref_lrn_bwd_t277 status_t execute(const exec_ctx_t &ctx) const override { 278 return execute_backward(ctx); 279 } 280 281 private: 282 status_t execute_backward(const exec_ctx_t &ctx) const; pddnnl::impl::gpu::ocl::ref_lrn_bwd_t283 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 284 285 compute::kernel_t kernel_; 286 }; 287 288 } // namespace ocl 289 } // namespace gpu 290 } // namespace impl 291 } // namespace dnnl 292 293 #endif 294