1 /******************************************************************************* 2 * Copyright 2016-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef COMMON_REORDER_PD_HPP 18 #define COMMON_REORDER_PD_HPP 19 20 #include <assert.h> 21 22 #include "c_types_map.hpp" 23 #include "engine.hpp" 24 #include "primitive.hpp" 25 #include "primitive_attr.hpp" 26 #include "primitive_desc.hpp" 27 #include "type_helpers.hpp" 28 #include "utils.hpp" 29 30 namespace dnnl { 31 namespace impl { 32 33 struct reorder_primitive_desc_iface_t : public dnnl_primitive_desc { reorder_primitive_desc_iface_tdnnl::impl::reorder_primitive_desc_iface_t34 reorder_primitive_desc_iface_t(const std::shared_ptr<primitive_desc_t> &pd, 35 engine_t *engine, engine_t *src_engine, engine_t *dst_engine) 36 : dnnl_primitive_desc(pd, engine) 37 , src_engine_(src_engine) 38 , dst_engine_(dst_engine) 39 , scratchpad_engine_(nullptr) {} 40 src_enginednnl::impl::reorder_primitive_desc_iface_t41 dnnl::impl::engine_t *src_engine() const override { return src_engine_; } dst_enginednnl::impl::reorder_primitive_desc_iface_t42 dnnl::impl::engine_t *dst_engine() const override { return dst_engine_; } 43 scratchpad_enginednnl::impl::reorder_primitive_desc_iface_t44 dnnl::impl::engine_t *scratchpad_engine() const override { 45 return scratchpad_engine_; 46 } 47 querydnnl::impl::reorder_primitive_desc_iface_t48 dnnl::impl::status_t query( 49 dnnl::impl::query_t what, int idx, void *result) const override { 50 auto status = dnnl::impl::status::success; 51 switch (what) { 52 case dnnl::impl::query::reorder_src_engine: 53 *(dnnl::impl::engine_t **)result = src_engine(); 54 break; 55 case dnnl::impl::query::reorder_dst_engine: 56 *(dnnl::impl::engine_t **)result = dst_engine(); 57 break; 58 default: status = dnnl_primitive_desc::query(what, idx, result); 59 } 60 return status; 61 } 62 create_primitive_ifacednnl::impl::reorder_primitive_desc_iface_t63 status_t create_primitive_iface( 64 std::pair<primitive_iface_t *, bool> &primitive_iface) 65 const override { 66 // Step 1: create impl::primitive_t or get it from primitive cache 67 std::pair<std::shared_ptr<primitive_t>, bool> p; 68 auto status = pd_->create_primitive(p, engine()); 69 if (status != status::success) return status; 70 // Step 2: create primitive_iface_t, init and return it to user 71 primitive_iface_t *p_iface = nullptr; 72 CHECK(safe_ptr_assign(p_iface, 73 new primitive_iface_t( 74 p.first, engine(), src_engine_, dst_engine_))); 75 status = p_iface->init(); 76 if (status != status::success) { 77 p_iface->release(); 78 return status; 79 } 80 primitive_iface = std::make_pair(p_iface, p.second); 81 return status::success; 82 } 83 84 private: 85 dnnl::impl::engine_t *src_engine_; 86 dnnl::impl::engine_t *dst_engine_; 87 dnnl::impl::engine_t *scratchpad_engine_; 88 }; 89 90 struct reorder_pd_t : public primitive_desc_t { descdnnl::impl::reorder_pd_t91 const reorder_desc_t *desc() const { return &desc_; } op_descdnnl::impl::reorder_pd_t92 const op_desc_t *op_desc() const override { 93 return reinterpret_cast<const op_desc_t *>(this->desc()); 94 } 95 arg_usagednnl::impl::reorder_pd_t96 arg_usage_t arg_usage(int arg) const override { 97 if (arg == DNNL_ARG_FROM) return arg_usage_t::input; 98 99 if (arg == DNNL_ARG_TO) return arg_usage_t::output; 100 101 return primitive_desc_t::arg_usage(arg); 102 } 103 arg_mddnnl::impl::reorder_pd_t104 const memory_desc_t *arg_md(int arg) const override { 105 switch (arg) { 106 case DNNL_ARG_FROM: return src_md(0); 107 case DNNL_ARG_TO: return dst_md(0); 108 default: return primitive_desc_t::arg_md(arg); 109 } 110 } 111 src_mddnnl::impl::reorder_pd_t112 const memory_desc_t *src_md(int index = 0) const override { 113 return index == 0 ? &src_md_ : &glob_zero_md; 114 } dst_mddnnl::impl::reorder_pd_t115 const memory_desc_t *dst_md(int index = 0) const override { 116 return index == 0 ? &dst_md_ : &glob_zero_md; 117 } 118 n_inputsdnnl::impl::reorder_pd_t119 int n_inputs() const override { return 1; } n_outputsdnnl::impl::reorder_pd_t120 int n_outputs() const override { return 1; } 121 alphadnnl::impl::reorder_pd_t122 float alpha() const { return attr()->output_scales_.scales_[0]; } betadnnl::impl::reorder_pd_t123 float beta() const { 124 const int sum_idx = attr()->post_ops_.find(primitive_kind::sum); 125 return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale; 126 } 127 128 protected: 129 reorder_desc_t desc_; 130 memory_desc_t src_md_; 131 memory_desc_t dst_md_; 132 reorder_pd_tdnnl::impl::reorder_pd_t133 reorder_pd_t(const primitive_attr_t *attr, engine_kind_t src_engine_kind, 134 const memory_desc_t *src_md, engine_kind_t dst_engine_kind, 135 const memory_desc_t *dst_md) 136 : primitive_desc_t(attr, primitive_kind::reorder) 137 , src_md_(*src_md) 138 , dst_md_(*dst_md) { 139 140 init_desc(src_engine_kind, dst_engine_kind, false); 141 } 142 reorder_pd_tdnnl::impl::reorder_pd_t143 reorder_pd_t(const reorder_pd_t &other) : primitive_desc_t(other) { 144 src_md_ = other.src_md_; 145 dst_md_ = other.dst_md_; 146 147 init_desc(other.desc_.src_engine_kind, other.desc_.dst_engine_kind, 148 other.desc_.is_cross_engine); 149 } 150 151 protected: init_descdnnl::impl::reorder_pd_t152 void init_desc(engine_kind_t src_engine_kind, engine_kind_t dst_engine_kind, 153 bool is_cross_engine) { 154 desc_ = reorder_desc_t(); 155 desc_.primitive_kind = primitive_kind::reorder; 156 desc_.src_md = &src_md_; 157 desc_.dst_md = &dst_md_; 158 desc_.src_engine_kind = src_engine_kind; 159 desc_.dst_engine_kind = dst_engine_kind; 160 desc_.is_cross_engine = is_cross_engine; 161 } 162 }; 163 164 } // namespace impl 165 } // namespace dnnl 166 167 #endif 168 169 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s 170