1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef GPU_OCL_REF_RESAMPLING_HPP 18 #define GPU_OCL_REF_RESAMPLING_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/nstl.hpp" 22 #include "common/primitive.hpp" 23 #include "common/type_helpers.hpp" 24 #include "gpu/compute/compute.hpp" 25 #include "gpu/gpu_primitive.hpp" 26 #include "gpu/gpu_resampling_pd.hpp" 27 #include "gpu/gpu_resource.hpp" 28 #include "gpu/ocl/ocl_stream.hpp" 29 #include "gpu/ocl/ocl_utils.hpp" 30 #include "gpu/primitive_conf.hpp" 31 32 namespace dnnl { 33 namespace impl { 34 namespace gpu { 35 namespace ocl { 36 37 struct ref_resampling_fwd_t : public gpu_primitive_t { 38 using gpu_primitive_t::gpu_primitive_t; 39 struct pd_t : public gpu_resampling_fwd_pd_t { pd_tdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t40 pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr, 41 const resampling_fwd_pd_t *hint_fwd_pd) 42 : gpu_resampling_fwd_pd_t(adesc, attr, hint_fwd_pd) {} ~pd_tdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t43 virtual ~pd_t() {} 44 45 DECLARE_COMMON_PD_T("ref:any", ref_resampling_fwd_t); 46 initdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t47 status_t init(engine_t *engine) { 48 using namespace data_type; 49 assert(engine->kind() == engine_kind::gpu); 50 using sm = primitive_attr_t::skip_mask_t; 51 const auto attr_skip_mask = sm::post_ops; 52 53 auto *compute_engine 54 = utils::downcast<compute::compute_engine_t *>(engine); 55 bool ok = is_fwd() && set_default_params() == status::success 56 && attr()->has_default_values(attr_skip_mask) 57 && post_ops_with_binary_ok(attr(), dst_md()->data_type, 5) 58 && attr_.set_default_formats(dst_md(0)) == status::success; 59 if (!ok) return status::unimplemented; 60 61 dispatch = compute_engine->create_dispatch(dst_md()); 62 dispatch.define_dim("MB", 0, dst_md()->padded_dims[0]); 63 dispatch.define_dim("C", 1, dst_md()->padded_dims[1]); 64 dispatch.define_dim("OD", nstl::max(2, dst_md()->ndims - 3), OD()); 65 dispatch.define_dim("OH", nstl::max(2, dst_md()->ndims - 2), OH()); 66 dispatch.define_dim("OW", nstl::max(2, dst_md()->ndims - 1), OW()); 67 dispatch.generate(); 68 attr_info = attr_info_t::create(attr()); 69 70 return status::success; 71 } 72 compute::dispatch_t dispatch; 73 attr_info_t attr_info; 74 }; 75 initdnnl::impl::gpu::ocl::ref_resampling_fwd_t76 status_t init(engine_t *engine) override { 77 using namespace alg_kind; 78 79 compute::kernel_ctx_t kernel_ctx; 80 81 status_t status = status::success; 82 const auto *desc = pd()->desc(); 83 84 kernel_ctx.set_data_type(pd()->src_md()->data_type); 85 86 kernel_ctx.define_int("IS_FWD", 1); 87 88 switch (desc->alg_kind) { 89 case resampling_nearest: 90 kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1); 91 break; 92 case resampling_linear: 93 kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1); 94 break; 95 default: status = status::unimplemented; 96 } 97 if (status != status::success) return status; 98 99 const memory_desc_wrapper src_d(pd()->src_md()); 100 const memory_desc_wrapper dst_d(pd()->dst_md()); 101 const int ndims = dst_d.ndims(); 102 103 kernel_ctx.define_int("NDIMS", ndims); 104 kernel_ctx.define_int("MB", pd()->MB()); 105 kernel_ctx.define_int("C", pd()->C()); 106 kernel_ctx.define_int("ID", pd()->ID()); 107 kernel_ctx.define_int("IH", pd()->IH()); 108 kernel_ctx.define_int("IW", pd()->IW()); 109 kernel_ctx.define_int("OD", pd()->OD()); 110 kernel_ctx.define_int("OH", pd()->OH()); 111 kernel_ctx.define_int("OW", pd()->OW()); 112 kernel_ctx.define_float("FD", pd()->FD()); 113 kernel_ctx.define_float("FH", pd()->FH()); 114 kernel_ctx.define_float("FW", pd()->FW()); 115 116 offsets_t off; 117 set_offsets(src_d, off.src_off); 118 set_offsets(dst_d, off.dst_off); 119 def_offsets(off.src_off, kernel_ctx, "SRC", ndims); 120 def_offsets(off.dst_off, kernel_ctx, "DST", ndims); 121 def_data_type(kernel_ctx, dst_d.data_type(), "DST"); 122 123 def_attr_info(kernel_ctx, pd()->attr_info, pd()->attr()->post_ops_); 124 def_dispatch(kernel_ctx, pd()->dispatch); 125 126 create_kernel(engine, &kernel_, "ref_resampling_fwd", kernel_ctx); 127 if (!kernel_) return status::runtime_error; 128 129 return status::success; 130 } 131 executednnl::impl::gpu::ocl::ref_resampling_fwd_t132 status_t execute(const exec_ctx_t &ctx) const override { 133 return execute_forward(ctx); 134 } 135 136 private: 137 status_t execute_forward(const exec_ctx_t &ctx) const; pddnnl::impl::gpu::ocl::ref_resampling_fwd_t138 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 139 compute::kernel_t kernel_; 140 }; 141 142 struct ref_resampling_bwd_t : public gpu_primitive_t { 143 using gpu_primitive_t::gpu_primitive_t; 144 struct pd_t : public gpu_resampling_bwd_pd_t { pd_tdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t145 pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr, 146 const resampling_fwd_pd_t *hint_fwd_pd) 147 : gpu_resampling_bwd_pd_t(adesc, attr, hint_fwd_pd) {} ~pd_tdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t148 virtual ~pd_t() {} 149 150 DECLARE_COMMON_PD_T("ref:any", ref_resampling_bwd_t); 151 initdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t152 status_t init(engine_t *engine) { 153 using namespace data_type; 154 assert(engine->kind() == engine_kind::gpu); 155 auto *compute_engine 156 = utils::downcast<compute::compute_engine_t *>(engine); 157 bool ok = !is_fwd() && set_default_params() == status::success 158 && attr()->has_default_values(); 159 if (!ok) return status::unimplemented; 160 161 dispatch = compute_engine->create_dispatch(diff_src_md()); 162 dispatch.define_dim("MB", 0, diff_src_md()->padded_dims[0]); 163 dispatch.define_dim("C", 1, diff_src_md()->padded_dims[1]); 164 dispatch.define_dim( 165 "ID", nstl::max(2, diff_src_md()->ndims - 3), ID()); 166 dispatch.define_dim( 167 "IH", nstl::max(2, diff_src_md()->ndims - 2), IH()); 168 dispatch.define_dim( 169 "IW", nstl::max(2, diff_src_md()->ndims - 1), IW()); 170 dispatch.generate(); 171 172 return status::success; 173 } 174 compute::dispatch_t dispatch; 175 }; 176 initdnnl::impl::gpu::ocl::ref_resampling_bwd_t177 status_t init(engine_t *engine) override { 178 using namespace alg_kind; 179 180 compute::kernel_ctx_t kernel_ctx; 181 182 status_t status = status::success; 183 const auto *desc = pd()->desc(); 184 185 kernel_ctx.set_data_type(pd()->diff_src_md()->data_type); 186 187 kernel_ctx.define_int("IS_BWD", 1); 188 189 switch (desc->alg_kind) { 190 case resampling_nearest: 191 kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1); 192 break; 193 case resampling_linear: 194 kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1); 195 break; 196 default: status = status::unimplemented; 197 } 198 if (status != status::success) return status; 199 200 const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); 201 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); 202 const int ndims = diff_dst_d.ndims(); 203 204 kernel_ctx.define_int("NDIMS", ndims); 205 kernel_ctx.define_int("MB", pd()->MB()); 206 kernel_ctx.define_int("C", pd()->C()); 207 kernel_ctx.define_int("ID", pd()->ID()); 208 kernel_ctx.define_int("IH", pd()->IH()); 209 kernel_ctx.define_int("IW", pd()->IW()); 210 kernel_ctx.define_int("OD", pd()->OD()); 211 kernel_ctx.define_int("OH", pd()->OH()); 212 kernel_ctx.define_int("OW", pd()->OW()); 213 kernel_ctx.define_float("FD", pd()->FD()); 214 kernel_ctx.define_float("FH", pd()->FH()); 215 kernel_ctx.define_float("FW", pd()->FW()); 216 217 offsets_t off; 218 set_offsets(diff_src_d, off.src_off); 219 set_offsets(diff_dst_d, off.dst_off); 220 def_offsets(off.src_off, kernel_ctx, "SRC", ndims); 221 def_offsets(off.dst_off, kernel_ctx, "DST", ndims); 222 def_data_type(kernel_ctx, diff_dst_d.data_type(), "DST"); 223 224 def_dispatch(kernel_ctx, pd()->dispatch); 225 226 create_kernel(engine, &kernel_, "ref_resampling_bwd", kernel_ctx); 227 if (!kernel_) return status::runtime_error; 228 229 return status::success; 230 } 231 executednnl::impl::gpu::ocl::ref_resampling_bwd_t232 status_t execute(const exec_ctx_t &ctx) const override { 233 return execute_backward(ctx); 234 } 235 236 private: 237 status_t execute_backward(const exec_ctx_t &ctx) const; pddnnl::impl::gpu::ocl::ref_resampling_bwd_t238 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 239 compute::kernel_t kernel_; 240 }; 241 242 } // namespace ocl 243 } // namespace gpu 244 } // namespace impl 245 } // namespace dnnl 246 247 #endif 248