1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef GPU_OCL_REF_BINARY_HPP 18 #define GPU_OCL_REF_BINARY_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/primitive.hpp" 22 #include "gpu/compute/compute.hpp" 23 #include "gpu/gpu_binary_pd.hpp" 24 #include "gpu/gpu_primitive.hpp" 25 #include "gpu/gpu_resource.hpp" 26 #include "gpu/ocl/ocl_stream.hpp" 27 #include "gpu/ocl/ocl_utils.hpp" 28 #include "gpu/primitive_conf.hpp" 29 30 namespace dnnl { 31 namespace impl { 32 namespace gpu { 33 namespace ocl { 34 35 struct ref_binary_t : public gpu_primitive_t { 36 using gpu_primitive_t::gpu_primitive_t; 37 struct pd_t : public gpu_binary_pd_t { 38 using gpu_binary_pd_t::gpu_binary_pd_t; 39 40 DECLARE_COMMON_PD_T("ocl:ref:any", ref_binary_t); 41 initdnnl::impl::gpu::ocl::ref_binary_t::pd_t42 status_t init(engine_t *engine) { 43 using namespace data_type; 44 using sm = primitive_attr_t::skip_mask_t; 45 46 const auto attr_skip_mask = sm::post_ops | sm::scales; 47 bool ok = set_default_params() == status::success 48 && (utils::everyone_is(bf16, src_md(0)->data_type, 49 src_md(1)->data_type, dst_md()->data_type) 50 || (utils::one_of( 51 src_md(0)->data_type, f16, f32, s8, u8) 52 && utils::one_of(src_md(1)->data_type, f16, 53 f32, s8, u8) 54 && utils::one_of(dst_md()->data_type, f16, 55 f32, s8, u8))) 56 && IMPLICATION(!attr()->scales_.has_default_values(), 57 check_scales_mask()) 58 && attr()->has_default_values(attr_skip_mask) 59 && post_ops_with_binary_ok( 60 attr(), dst_md()->data_type, MAX_NDIMS) 61 && attr_.set_default_formats(dst_md(0)) == status::success; 62 63 if (!ok) return status::unimplemented; 64 65 return init_conf(engine); 66 } 67 68 status_t init_conf(engine_t *engine); 69 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; 70 with_scalesdnnl::impl::gpu::ocl::ref_binary_t::pd_t71 bool with_scales(int position) const { 72 return !attr()->scales_.get(position).has_default_values(); 73 } 74 with_scalesdnnl::impl::gpu::ocl::ref_binary_t::pd_t75 bool with_scales() const { 76 return with_scales(DNNL_ARG_SRC_0) || with_scales(DNNL_ARG_SRC_1); 77 } 78 get_scalednnl::impl::gpu::ocl::ref_binary_t::pd_t79 float get_scale(int position) const { 80 return *attr()->scales_.get(position).scales_; 81 } 82 with_eltwisednnl::impl::gpu::ocl::ref_binary_t::pd_t83 bool with_eltwise(int position) const { 84 return attr()->post_ops_.contain(primitive_kind::eltwise, position); 85 } 86 with_sumdnnl::impl::gpu::ocl::ref_binary_t::pd_t87 bool with_sum() const { 88 return attr()->post_ops_.find(primitive_kind::sum) != -1; 89 } 90 eltwise_alphadnnl::impl::gpu::ocl::ref_binary_t::pd_t91 float eltwise_alpha() const { 92 const int eltwise_idx 93 = attr()->post_ops_.find(primitive_kind::eltwise); 94 return eltwise_idx != -1 95 ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alpha 96 : 1.0f; 97 } 98 eltwise_betadnnl::impl::gpu::ocl::ref_binary_t::pd_t99 float eltwise_beta() const { 100 const int eltwise_idx 101 = attr()->post_ops_.find(primitive_kind::eltwise); 102 return eltwise_idx != -1 103 ? attr()->post_ops_.entry_[eltwise_idx].eltwise.beta 104 : 0.0f; 105 } 106 eltwise_scalednnl::impl::gpu::ocl::ref_binary_t::pd_t107 float eltwise_scale() const { 108 const int eltwise_idx 109 = attr()->post_ops_.find(primitive_kind::eltwise); 110 return eltwise_idx != -1 111 ? attr()->post_ops_.entry_[eltwise_idx].eltwise.scale 112 : 1.0f; 113 } 114 sum_scalednnl::impl::gpu::ocl::ref_binary_t::pd_t115 float sum_scale() const { 116 const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); 117 return sum_idx != -1 ? attr()->post_ops_.entry_[sum_idx].sum.scale 118 : 0.0f; 119 } 120 eltwise_alg_kinddnnl::impl::gpu::ocl::ref_binary_t::pd_t121 alg_kind_t eltwise_alg_kind() const { 122 const int eltwise_idx 123 = attr()->post_ops_.find(primitive_kind::eltwise); 124 return eltwise_idx != -1 125 ? attr()->post_ops_.entry_[eltwise_idx].eltwise.alg 126 : dnnl_alg_kind_undef; 127 } 128 129 binary_conf_t conf; 130 131 private: check_scales_maskdnnl::impl::gpu::ocl::ref_binary_t::pd_t132 bool check_scales_mask() const { 133 for (const auto &s : attr()->scales_.scales_) { 134 if (s.second.mask_ != 0) return false; 135 } 136 return true; 137 } 138 }; 139 initdnnl::impl::gpu::ocl::ref_binary_t140 status_t init(engine_t *engine) override { 141 compute::kernel_ctx_t kernel_ctx; 142 143 auto status = pd()->init_kernel_ctx(kernel_ctx); 144 if (status != status::success) return status; 145 146 create_kernel(engine, &kernel_, "ref_binary", kernel_ctx); 147 if (!kernel_) return status::runtime_error; 148 149 return status::success; 150 } 151 executednnl::impl::gpu::ocl::ref_binary_t152 status_t execute(const exec_ctx_t &ctx) const override { 153 return execute_ref(ctx); 154 } 155 156 private: 157 status_t execute_ref(const exec_ctx_t &ctx) const; pddnnl::impl::gpu::ocl::ref_binary_t158 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 159 compute::kernel_t kernel_; 160 }; 161 162 } // namespace ocl 163 } // namespace gpu 164 } // namespace impl 165 } // namespace dnnl 166 167 #endif 168