1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef GPU_OCL_XE_HP_1ST_BWD_CONVOLUTION_HPP
18 #define GPU_OCL_XE_HP_1ST_BWD_CONVOLUTION_HPP
19 
20 #include <assert.h>
21 
22 #include "common/c_types_map.hpp"
23 #include "common/primitive.hpp"
24 #include "gpu/compute/compute.hpp"
25 #include "gpu/gpu_convolution_pd.hpp"
26 #include "gpu/gpu_eltwise_pd.hpp"
27 #include "gpu/gpu_primitive.hpp"
28 #include "gpu/gpu_resource.hpp"
29 #include "gpu/ocl/ocl_stream.hpp"
30 #include "gpu/ocl/ocl_utils.hpp"
31 #include "gpu/primitive_conf.hpp"
32 
33 namespace dnnl {
34 namespace impl {
35 namespace gpu {
36 namespace ocl {
37 
38 struct xe_hp_1st_convolution_bwd_weights_t : public gpu_primitive_t {
39     struct pd_t : public gpu_convolution_bwd_weights_pd_t {
pd_tdnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t::pd_t40         pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
41                 const convolution_fwd_pd_t *hint_fwd_pd)
42             : gpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) {}
43 
44         pd_t(const pd_t &rhs) = default;
45 
46         DECLARE_COMMON_PD_T(
47                 "ocl:xe_hp:1st", xe_hp_1st_convolution_bwd_weights_t);
48 
initdnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t::pd_t49         status_t init(engine_t *engine) {
50             using namespace data_type;
51             using namespace prop_kind;
52             assert(engine->kind() == engine_kind::gpu);
53             auto *compute_engine
54                     = utils::downcast<compute::compute_engine_t *>(engine);
55             if (!compute_engine->is_xe_hp() && !compute_engine->is_xe_hpg())
56                 return status::unimplemented;
57             bool ok = set_default_alg_kind(alg_kind::convolution_direct)
58                     && this->desc()->prop_kind == backward_weights
59                     && this->desc()->alg_kind == alg_kind::convolution_direct
60                     && utils::one_of(this->desc()->diff_weights_desc.data_type,
61                             bf16, f32)
62                     && utils::one_of(this->desc()->src_desc.data_type, bf16)
63                     && utils::one_of(
64                             this->desc()->diff_dst_desc.data_type, bf16)
65                     && compute_engine->mayiuse(
66                             compute::device_ext_t::intel_subgroups)
67                     && compute_engine->mayiuse(compute::device_ext_t::
68                                     intel_subgroup_local_block_io)
69                     && compute_engine->mayiuse(
70                             compute::device_ext_t::khr_int64_base_atomics)
71                     && !has_zero_dim_memory() && attr()->has_default_values();
72             if (!ok) return status::unimplemented;
73 
74             CHECK(init_conf(engine));
75 
76             if (!compute_engine->mayiuse_sub_group(conf.sub_group_size))
77                 return status::unimplemented;
78 
79             if (!IMPLICATION(utils::one_of(bf16,
80                                      this->desc()->diff_weights_desc.data_type,
81                                      this->desc()->src_desc.data_type,
82                                      this->desc()->diff_dst_desc.data_type),
83                         conf.ver == ver_1stconv))
84                 return status::unimplemented;
85 
86             init_scratchpad();
87             return status::success;
88         }
89 
90         status_t init_conf(engine_t *engine);
91         status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
92 
93         conv_conf_t conf;
94         std::shared_ptr<primitive_desc_t> rpd_wei_;
95         std::shared_ptr<primitive_desc_t> rpd_bia_;
96 
97     private:
98         status_t init_scratchpad();
99     };
100 
xe_hp_1st_convolution_bwd_weights_tdnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t101     xe_hp_1st_convolution_bwd_weights_t(const pd_t *apd)
102         : gpu_primitive_t(apd) {}
103 
initdnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t104     status_t init(engine_t *engine) override {
105         const char *kernel_name;
106 
107         kernel_name = "xe_hp_1st_conv_bwd_weights";
108 
109         if (pd()->conf.reorder_wei) {
110             CHECK(pd()->rpd_wei_->create_primitive(wei_reorder_, engine));
111         }
112         if (pd()->conf.reorder_bias) {
113             CHECK(pd()->rpd_bia_->create_primitive(bia_reorder_, engine));
114         }
115         compute::kernel_ctx_t kernel_ctx;
116         status_t status = pd()->init_kernel_ctx(kernel_ctx);
117         if (status != status::success) return status;
118 
119         create_kernel(engine, &kernel_, kernel_name, kernel_ctx);
120         if (!kernel_) return status::runtime_error;
121         return status::success;
122     }
123 
nested_primitivesdnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t124     primitive_list_t nested_primitives() const override {
125         primitive_list_t prim_list;
126         if (pd()->conf.reorder_wei)
127             prim_list.emplace(prim_list.begin(), wei_reorder_.get());
128         if (pd()->conf.reorder_bias)
129             prim_list.emplace(prim_list.begin(), bia_reorder_.get());
130 
131         return prim_list;
132     }
133 
executednnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t134     status_t execute(const exec_ctx_t &ctx) const override {
135         return execute_backward_weights(ctx);
136     }
137 
138 private:
139     status_t execute_backward_weights(const exec_ctx_t &ctx) const;
pddnnl::impl::gpu::ocl::xe_hp_1st_convolution_bwd_weights_t140     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
141     compute::kernel_t kernel_;
142     std::shared_ptr<primitive_t> wei_reorder_;
143     std::shared_ptr<primitive_t> bia_reorder_;
144 };
145 
146 } // namespace ocl
147 } // namespace gpu
148 } // namespace impl
149 } // namespace dnnl
150 
151 #endif
152 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
153