1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef COMMON_REORDER_PD_HPP
18 #define COMMON_REORDER_PD_HPP
19 
20 #include <assert.h>
21 
22 #include "c_types_map.hpp"
23 #include "engine.hpp"
24 #include "primitive.hpp"
25 #include "primitive_attr.hpp"
26 #include "primitive_desc.hpp"
27 #include "type_helpers.hpp"
28 #include "utils.hpp"
29 
30 namespace dnnl {
31 namespace impl {
32 
33 struct reorder_primitive_desc_iface_t : public dnnl_primitive_desc {
reorder_primitive_desc_iface_tdnnl::impl::reorder_primitive_desc_iface_t34     reorder_primitive_desc_iface_t(const std::shared_ptr<primitive_desc_t> &pd,
35             engine_t *engine, engine_t *src_engine, engine_t *dst_engine)
36         : dnnl_primitive_desc(pd, engine)
37         , src_engine_(src_engine)
38         , dst_engine_(dst_engine)
39         , scratchpad_engine_(nullptr) {}
40 
src_enginednnl::impl::reorder_primitive_desc_iface_t41     dnnl::impl::engine_t *src_engine() const override { return src_engine_; }
dst_enginednnl::impl::reorder_primitive_desc_iface_t42     dnnl::impl::engine_t *dst_engine() const override { return dst_engine_; }
43 
scratchpad_enginednnl::impl::reorder_primitive_desc_iface_t44     dnnl::impl::engine_t *scratchpad_engine() const override {
45         return scratchpad_engine_;
46     }
47 
querydnnl::impl::reorder_primitive_desc_iface_t48     dnnl::impl::status_t query(
49             dnnl::impl::query_t what, int idx, void *result) const override {
50         auto status = dnnl::impl::status::success;
51         switch (what) {
52             case dnnl::impl::query::reorder_src_engine:
53                 *(dnnl::impl::engine_t **)result = src_engine();
54                 break;
55             case dnnl::impl::query::reorder_dst_engine:
56                 *(dnnl::impl::engine_t **)result = dst_engine();
57                 break;
58             default: status = dnnl_primitive_desc::query(what, idx, result);
59         }
60         return status;
61     }
62 
create_primitive_ifacednnl::impl::reorder_primitive_desc_iface_t63     status_t create_primitive_iface(
64             std::pair<primitive_iface_t *, bool> &primitive_iface)
65             const override {
66         // Step 1: create impl::primitive_t or get it from primitive cache
67         std::pair<std::shared_ptr<primitive_t>, bool> p;
68         auto status = pd_->create_primitive(p, engine());
69         if (status != status::success) return status;
70         // Step 2: create primitive_iface_t, init and return it to user
71         primitive_iface_t *p_iface = nullptr;
72         CHECK(safe_ptr_assign(p_iface,
73                 new primitive_iface_t(
74                         p.first, engine(), src_engine_, dst_engine_)));
75         status = p_iface->init();
76         if (status != status::success) {
77             p_iface->release();
78             return status;
79         }
80         primitive_iface = std::make_pair(p_iface, p.second);
81         return status::success;
82     }
83 
84 private:
85     dnnl::impl::engine_t *src_engine_;
86     dnnl::impl::engine_t *dst_engine_;
87     dnnl::impl::engine_t *scratchpad_engine_;
88 };
89 
90 struct reorder_pd_t : public primitive_desc_t {
descdnnl::impl::reorder_pd_t91     const reorder_desc_t *desc() const { return &desc_; }
op_descdnnl::impl::reorder_pd_t92     const op_desc_t *op_desc() const override {
93         return reinterpret_cast<const op_desc_t *>(this->desc());
94     }
95 
arg_usagednnl::impl::reorder_pd_t96     arg_usage_t arg_usage(int arg) const override {
97         if (arg == DNNL_ARG_FROM) return arg_usage_t::input;
98 
99         if (arg == DNNL_ARG_TO) return arg_usage_t::output;
100 
101         return primitive_desc_t::arg_usage(arg);
102     }
103 
arg_mddnnl::impl::reorder_pd_t104     const memory_desc_t *arg_md(int arg) const override {
105         switch (arg) {
106             case DNNL_ARG_FROM: return src_md(0);
107             case DNNL_ARG_TO: return dst_md(0);
108             default: return primitive_desc_t::arg_md(arg);
109         }
110     }
111 
src_mddnnl::impl::reorder_pd_t112     const memory_desc_t *src_md(int index = 0) const override {
113         return index == 0 ? &src_md_ : &glob_zero_md;
114     }
dst_mddnnl::impl::reorder_pd_t115     const memory_desc_t *dst_md(int index = 0) const override {
116         return index == 0 ? &dst_md_ : &glob_zero_md;
117     }
118 
n_inputsdnnl::impl::reorder_pd_t119     int n_inputs() const override { return 1; }
n_outputsdnnl::impl::reorder_pd_t120     int n_outputs() const override { return 1; }
121 
alphadnnl::impl::reorder_pd_t122     float alpha() const { return attr()->output_scales_.scales_[0]; }
betadnnl::impl::reorder_pd_t123     float beta() const {
124         const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
125         return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
126     }
127 
128 protected:
129     reorder_desc_t desc_;
130     memory_desc_t src_md_;
131     memory_desc_t dst_md_;
132 
reorder_pd_tdnnl::impl::reorder_pd_t133     reorder_pd_t(const primitive_attr_t *attr, engine_kind_t src_engine_kind,
134             const memory_desc_t *src_md, engine_kind_t dst_engine_kind,
135             const memory_desc_t *dst_md)
136         : primitive_desc_t(attr, primitive_kind::reorder)
137         , src_md_(*src_md)
138         , dst_md_(*dst_md) {
139 
140         init_desc(src_engine_kind, dst_engine_kind, false);
141     }
142 
reorder_pd_tdnnl::impl::reorder_pd_t143     reorder_pd_t(const reorder_pd_t &other) : primitive_desc_t(other) {
144         src_md_ = other.src_md_;
145         dst_md_ = other.dst_md_;
146 
147         init_desc(other.desc_.src_engine_kind, other.desc_.dst_engine_kind,
148                 other.desc_.is_cross_engine);
149     }
150 
151 protected:
init_descdnnl::impl::reorder_pd_t152     void init_desc(engine_kind_t src_engine_kind, engine_kind_t dst_engine_kind,
153             bool is_cross_engine) {
154         desc_ = reorder_desc_t();
155         desc_.primitive_kind = primitive_kind::reorder;
156         desc_.src_md = &src_md_;
157         desc_.dst_md = &dst_md_;
158         desc_.src_engine_kind = src_engine_kind;
159         desc_.dst_engine_kind = dst_engine_kind;
160         desc_.is_cross_engine = is_cross_engine;
161     }
162 };
163 
164 } // namespace impl
165 } // namespace dnnl
166 
167 #endif
168 
169 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
170