1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_OCL_REF_RESAMPLING_HPP
18 #define GPU_OCL_REF_RESAMPLING_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/nstl.hpp"
22 #include "common/primitive.hpp"
23 #include "common/type_helpers.hpp"
24 #include "gpu/compute/compute.hpp"
25 #include "gpu/gpu_primitive.hpp"
26 #include "gpu/gpu_resampling_pd.hpp"
27 #include "gpu/gpu_resource.hpp"
28 #include "gpu/ocl/ocl_stream.hpp"
29 #include "gpu/ocl/ocl_utils.hpp"
30 #include "gpu/primitive_conf.hpp"
31 
32 namespace dnnl {
33 namespace impl {
34 namespace gpu {
35 namespace ocl {
36 
37 struct ref_resampling_fwd_t : public gpu_primitive_t {
38     using gpu_primitive_t::gpu_primitive_t;
39     struct pd_t : public gpu_resampling_fwd_pd_t {
pd_tdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t40         pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr,
41                 const resampling_fwd_pd_t *hint_fwd_pd)
42             : gpu_resampling_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
~pd_tdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t43         virtual ~pd_t() {}
44 
45         DECLARE_COMMON_PD_T("ref:any", ref_resampling_fwd_t);
46 
initdnnl::impl::gpu::ocl::ref_resampling_fwd_t::pd_t47         status_t init(engine_t *engine) {
48             using namespace data_type;
49             assert(engine->kind() == engine_kind::gpu);
50             using sm = primitive_attr_t::skip_mask_t;
51             const auto attr_skip_mask = sm::post_ops;
52 
53             auto *compute_engine
54                     = utils::downcast<compute::compute_engine_t *>(engine);
55             bool ok = is_fwd() && set_default_params() == status::success
56                     && attr()->has_default_values(attr_skip_mask)
57                     && post_ops_with_binary_ok(attr(), dst_md()->data_type, 5)
58                     && attr_.set_default_formats(dst_md(0)) == status::success;
59             if (!ok) return status::unimplemented;
60 
61             dispatch = compute_engine->create_dispatch(dst_md());
62             dispatch.define_dim("MB", 0, dst_md()->padded_dims[0]);
63             dispatch.define_dim("C", 1, dst_md()->padded_dims[1]);
64             dispatch.define_dim("OD", nstl::max(2, dst_md()->ndims - 3), OD());
65             dispatch.define_dim("OH", nstl::max(2, dst_md()->ndims - 2), OH());
66             dispatch.define_dim("OW", nstl::max(2, dst_md()->ndims - 1), OW());
67             dispatch.generate();
68             attr_info = attr_info_t::create(attr());
69 
70             return status::success;
71         }
72         compute::dispatch_t dispatch;
73         attr_info_t attr_info;
74     };
75 
initdnnl::impl::gpu::ocl::ref_resampling_fwd_t76     status_t init(engine_t *engine) override {
77         using namespace alg_kind;
78 
79         compute::kernel_ctx_t kernel_ctx;
80 
81         status_t status = status::success;
82         const auto *desc = pd()->desc();
83 
84         kernel_ctx.set_data_type(pd()->src_md()->data_type);
85 
86         kernel_ctx.define_int("IS_FWD", 1);
87 
88         switch (desc->alg_kind) {
89             case resampling_nearest:
90                 kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1);
91                 break;
92             case resampling_linear:
93                 kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1);
94                 break;
95             default: status = status::unimplemented;
96         }
97         if (status != status::success) return status;
98 
99         const memory_desc_wrapper src_d(pd()->src_md());
100         const memory_desc_wrapper dst_d(pd()->dst_md());
101         const int ndims = dst_d.ndims();
102 
103         kernel_ctx.define_int("NDIMS", ndims);
104         kernel_ctx.define_int("MB", pd()->MB());
105         kernel_ctx.define_int("C", pd()->C());
106         kernel_ctx.define_int("ID", pd()->ID());
107         kernel_ctx.define_int("IH", pd()->IH());
108         kernel_ctx.define_int("IW", pd()->IW());
109         kernel_ctx.define_int("OD", pd()->OD());
110         kernel_ctx.define_int("OH", pd()->OH());
111         kernel_ctx.define_int("OW", pd()->OW());
112         kernel_ctx.define_float("FD", pd()->FD());
113         kernel_ctx.define_float("FH", pd()->FH());
114         kernel_ctx.define_float("FW", pd()->FW());
115 
116         offsets_t off;
117         set_offsets(src_d, off.src_off);
118         set_offsets(dst_d, off.dst_off);
119         def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
120         def_offsets(off.dst_off, kernel_ctx, "DST", ndims);
121         def_data_type(kernel_ctx, dst_d.data_type(), "DST");
122 
123         def_attr_info(kernel_ctx, pd()->attr_info, pd()->attr()->post_ops_);
124         def_dispatch(kernel_ctx, pd()->dispatch);
125 
126         create_kernel(engine, &kernel_, "ref_resampling_fwd", kernel_ctx);
127         if (!kernel_) return status::runtime_error;
128 
129         return status::success;
130     }
131 
executednnl::impl::gpu::ocl::ref_resampling_fwd_t132     status_t execute(const exec_ctx_t &ctx) const override {
133         return execute_forward(ctx);
134     }
135 
136 private:
137     status_t execute_forward(const exec_ctx_t &ctx) const;
pddnnl::impl::gpu::ocl::ref_resampling_fwd_t138     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
139     compute::kernel_t kernel_;
140 };
141 
142 struct ref_resampling_bwd_t : public gpu_primitive_t {
143     using gpu_primitive_t::gpu_primitive_t;
144     struct pd_t : public gpu_resampling_bwd_pd_t {
pd_tdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t145         pd_t(const resampling_desc_t *adesc, const primitive_attr_t *attr,
146                 const resampling_fwd_pd_t *hint_fwd_pd)
147             : gpu_resampling_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
~pd_tdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t148         virtual ~pd_t() {}
149 
150         DECLARE_COMMON_PD_T("ref:any", ref_resampling_bwd_t);
151 
initdnnl::impl::gpu::ocl::ref_resampling_bwd_t::pd_t152         status_t init(engine_t *engine) {
153             using namespace data_type;
154             assert(engine->kind() == engine_kind::gpu);
155             auto *compute_engine
156                     = utils::downcast<compute::compute_engine_t *>(engine);
157             bool ok = !is_fwd() && set_default_params() == status::success
158                     && attr()->has_default_values();
159             if (!ok) return status::unimplemented;
160 
161             dispatch = compute_engine->create_dispatch(diff_src_md());
162             dispatch.define_dim("MB", 0, diff_src_md()->padded_dims[0]);
163             dispatch.define_dim("C", 1, diff_src_md()->padded_dims[1]);
164             dispatch.define_dim(
165                     "ID", nstl::max(2, diff_src_md()->ndims - 3), ID());
166             dispatch.define_dim(
167                     "IH", nstl::max(2, diff_src_md()->ndims - 2), IH());
168             dispatch.define_dim(
169                     "IW", nstl::max(2, diff_src_md()->ndims - 1), IW());
170             dispatch.generate();
171 
172             return status::success;
173         }
174         compute::dispatch_t dispatch;
175     };
176 
initdnnl::impl::gpu::ocl::ref_resampling_bwd_t177     status_t init(engine_t *engine) override {
178         using namespace alg_kind;
179 
180         compute::kernel_ctx_t kernel_ctx;
181 
182         status_t status = status::success;
183         const auto *desc = pd()->desc();
184 
185         kernel_ctx.set_data_type(pd()->diff_src_md()->data_type);
186 
187         kernel_ctx.define_int("IS_BWD", 1);
188 
189         switch (desc->alg_kind) {
190             case resampling_nearest:
191                 kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1);
192                 break;
193             case resampling_linear:
194                 kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1);
195                 break;
196             default: status = status::unimplemented;
197         }
198         if (status != status::success) return status;
199 
200         const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
201         const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
202         const int ndims = diff_dst_d.ndims();
203 
204         kernel_ctx.define_int("NDIMS", ndims);
205         kernel_ctx.define_int("MB", pd()->MB());
206         kernel_ctx.define_int("C", pd()->C());
207         kernel_ctx.define_int("ID", pd()->ID());
208         kernel_ctx.define_int("IH", pd()->IH());
209         kernel_ctx.define_int("IW", pd()->IW());
210         kernel_ctx.define_int("OD", pd()->OD());
211         kernel_ctx.define_int("OH", pd()->OH());
212         kernel_ctx.define_int("OW", pd()->OW());
213         kernel_ctx.define_float("FD", pd()->FD());
214         kernel_ctx.define_float("FH", pd()->FH());
215         kernel_ctx.define_float("FW", pd()->FW());
216 
217         offsets_t off;
218         set_offsets(diff_src_d, off.src_off);
219         set_offsets(diff_dst_d, off.dst_off);
220         def_offsets(off.src_off, kernel_ctx, "SRC", ndims);
221         def_offsets(off.dst_off, kernel_ctx, "DST", ndims);
222         def_data_type(kernel_ctx, diff_dst_d.data_type(), "DST");
223 
224         def_dispatch(kernel_ctx, pd()->dispatch);
225 
226         create_kernel(engine, &kernel_, "ref_resampling_bwd", kernel_ctx);
227         if (!kernel_) return status::runtime_error;
228 
229         return status::success;
230     }
231 
executednnl::impl::gpu::ocl::ref_resampling_bwd_t232     status_t execute(const exec_ctx_t &ctx) const override {
233         return execute_backward(ctx);
234     }
235 
236 private:
237     status_t execute_backward(const exec_ctx_t &ctx) const;
pddnnl::impl::gpu::ocl::ref_resampling_bwd_t238     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
239     compute::kernel_t kernel_;
240 };
241 
242 } // namespace ocl
243 } // namespace gpu
244 } // namespace impl
245 } // namespace dnnl
246 
247 #endif
248