1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include "oneapi/dnnl/dnnl.h"
19
20 #include "c_types_map.hpp"
21 #include "engine.hpp"
22 #include "impl_list_item.hpp"
23 #include "primitive_cache.hpp"
24 #include "primitive_hashing.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27
28 #include "reorder_pd.hpp"
29
30 using namespace dnnl::impl;
31 using namespace dnnl::impl::utils;
32 using namespace dnnl::impl::status;
33
34 namespace dnnl {
35 namespace impl {
36
37 namespace {
get_reorder_engine(engine_t * src_engine,engine_t * dst_engine)38 engine_t *get_reorder_engine(engine_t *src_engine, engine_t *dst_engine) {
39 auto s_ek = src_engine->kind();
40 auto d_ek = dst_engine->kind();
41 auto s_rk = src_engine->runtime_kind();
42 auto d_rk = dst_engine->runtime_kind();
43
44 if (is_native_runtime(d_rk)) return src_engine;
45
46 if (is_native_runtime(s_rk)) return dst_engine;
47
48 if (d_ek == engine_kind::cpu) return src_engine;
49
50 if (s_ek == engine_kind::cpu) return dst_engine;
51
52 assert(s_ek == engine_kind::gpu);
53 assert(d_ek == engine_kind::gpu);
54 return src_engine;
55 }
56 } // namespace
57
reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> & pd,engine_t * engine,const memory_desc_t * src_md,engine_t * src_engine,const memory_desc_t * dst_md,engine_t * dst_engine,const primitive_attr_t * attr)58 status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
59 engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine,
60 const memory_desc_t *dst_md, engine_t *dst_engine,
61 const primitive_attr_t *attr) {
62 pd.reset();
63
64 auto s_ek = src_engine->kind();
65 auto d_ek = dst_engine->kind();
66 if (!IMPLICATION(s_ek != d_ek, utils::one_of(engine_kind::cpu, s_ek, d_ek)))
67 return invalid_arguments;
68
69 auto s_mdw = memory_desc_wrapper(*src_md);
70 auto d_mdw = memory_desc_wrapper(*dst_md);
71
72 if (!s_mdw.consistent_with(d_mdw)) return invalid_arguments;
73
74 if (attr == nullptr) attr = &default_attr();
75
76 bool is_cross_engine = src_engine != dst_engine
77 && utils::one_of(
78 engine_kind::gpu, src_engine->kind(), dst_engine->kind());
79
80 dnnl_reorder_desc_t desc = {primitive_kind::reorder, src_md, dst_md, s_ek,
81 d_ek, is_cross_engine};
82 primitive_hashing::key_t key(
83 engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
84 pd = primitive_cache().get_pd(key);
85 if (pd) return success;
86
87 for (auto r = engine->get_reorder_implementation_list(src_md, dst_md); *r;
88 ++r) {
89 reorder_pd_t *reorder_pd = nullptr;
90 if ((*r)(&reorder_pd, engine, attr, src_engine, src_md, dst_engine,
91 dst_md)
92 == success) {
93 pd.reset(reorder_pd);
94 return success;
95 }
96 }
97 return unimplemented;
98 }
99
reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> & pd,engine_t * engine,const memory_desc_t * src_md,const memory_desc_t * dst_md,const primitive_attr_t * attr)100 status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
101 engine_t *engine, const memory_desc_t *src_md,
102 const memory_desc_t *dst_md, const primitive_attr_t *attr) {
103 return reorder_primitive_desc_create(
104 pd, engine, src_md, engine, dst_md, engine, attr);
105 }
106
107 } // namespace impl
108 } // namespace dnnl
109
dnnl_reorder_primitive_desc_create(primitive_desc_iface_t ** reorder_pd_iface,const memory_desc_t * src_md,engine_t * src_engine,const memory_desc_t * dst_md,engine_t * dst_engine,const primitive_attr_t * attr)110 status_t dnnl_reorder_primitive_desc_create(
111 primitive_desc_iface_t **reorder_pd_iface, const memory_desc_t *src_md,
112 engine_t *src_engine, const memory_desc_t *dst_md, engine_t *dst_engine,
113 const primitive_attr_t *attr) {
114 if (any_null(reorder_pd_iface, src_engine, src_md, dst_engine, dst_md))
115 return invalid_arguments;
116
117 std::shared_ptr<primitive_desc_t> pd;
118 auto e = get_reorder_engine(src_engine, dst_engine);
119 CHECK(reorder_primitive_desc_create(
120 pd, e, src_md, src_engine, dst_md, dst_engine, attr));
121
122 return safe_ptr_assign(*reorder_pd_iface,
123 new reorder_primitive_desc_iface_t(pd, e, src_engine, dst_engine));
124 }
125
126 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
127