1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include "oneapi/dnnl/dnnl.h"
19 
20 #include "c_types_map.hpp"
21 #include "engine.hpp"
22 #include "impl_list_item.hpp"
23 #include "primitive_cache.hpp"
24 #include "primitive_hashing.hpp"
25 #include "type_helpers.hpp"
26 #include "utils.hpp"
27 
28 #include "reorder_pd.hpp"
29 
30 using namespace dnnl::impl;
31 using namespace dnnl::impl::utils;
32 using namespace dnnl::impl::status;
33 
34 namespace dnnl {
35 namespace impl {
36 
37 namespace {
get_reorder_engine(engine_t * src_engine,engine_t * dst_engine)38 engine_t *get_reorder_engine(engine_t *src_engine, engine_t *dst_engine) {
39     auto s_ek = src_engine->kind();
40     auto d_ek = dst_engine->kind();
41     auto s_rk = src_engine->runtime_kind();
42     auto d_rk = dst_engine->runtime_kind();
43 
44     if (is_native_runtime(d_rk)) return src_engine;
45 
46     if (is_native_runtime(s_rk)) return dst_engine;
47 
48     if (d_ek == engine_kind::cpu) return src_engine;
49 
50     if (s_ek == engine_kind::cpu) return dst_engine;
51 
52     assert(s_ek == engine_kind::gpu);
53     assert(d_ek == engine_kind::gpu);
54     return src_engine;
55 }
56 } // namespace
57 
reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> & pd,engine_t * engine,const memory_desc_t * src_md,engine_t * src_engine,const memory_desc_t * dst_md,engine_t * dst_engine,const primitive_attr_t * attr)58 status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
59         engine_t *engine, const memory_desc_t *src_md, engine_t *src_engine,
60         const memory_desc_t *dst_md, engine_t *dst_engine,
61         const primitive_attr_t *attr) {
62     pd.reset();
63 
64     auto s_ek = src_engine->kind();
65     auto d_ek = dst_engine->kind();
66     if (!IMPLICATION(s_ek != d_ek, utils::one_of(engine_kind::cpu, s_ek, d_ek)))
67         return invalid_arguments;
68 
69     auto s_mdw = memory_desc_wrapper(*src_md);
70     auto d_mdw = memory_desc_wrapper(*dst_md);
71 
72     if (!s_mdw.consistent_with(d_mdw)) return invalid_arguments;
73 
74     if (attr == nullptr) attr = &default_attr();
75 
76     bool is_cross_engine = src_engine != dst_engine
77             && utils::one_of(
78                     engine_kind::gpu, src_engine->kind(), dst_engine->kind());
79 
80     dnnl_reorder_desc_t desc = {primitive_kind::reorder, src_md, dst_md, s_ek,
81             d_ek, is_cross_engine};
82     primitive_hashing::key_t key(
83             engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
84     pd = primitive_cache().get_pd(key);
85     if (pd) return success;
86 
87     for (auto r = engine->get_reorder_implementation_list(src_md, dst_md); *r;
88             ++r) {
89         reorder_pd_t *reorder_pd = nullptr;
90         if ((*r)(&reorder_pd, engine, attr, src_engine, src_md, dst_engine,
91                     dst_md)
92                 == success) {
93             pd.reset(reorder_pd);
94             return success;
95         }
96     }
97     return unimplemented;
98 }
99 
reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> & pd,engine_t * engine,const memory_desc_t * src_md,const memory_desc_t * dst_md,const primitive_attr_t * attr)100 status_t reorder_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
101         engine_t *engine, const memory_desc_t *src_md,
102         const memory_desc_t *dst_md, const primitive_attr_t *attr) {
103     return reorder_primitive_desc_create(
104             pd, engine, src_md, engine, dst_md, engine, attr);
105 }
106 
107 } // namespace impl
108 } // namespace dnnl
109 
dnnl_reorder_primitive_desc_create(primitive_desc_iface_t ** reorder_pd_iface,const memory_desc_t * src_md,engine_t * src_engine,const memory_desc_t * dst_md,engine_t * dst_engine,const primitive_attr_t * attr)110 status_t dnnl_reorder_primitive_desc_create(
111         primitive_desc_iface_t **reorder_pd_iface, const memory_desc_t *src_md,
112         engine_t *src_engine, const memory_desc_t *dst_md, engine_t *dst_engine,
113         const primitive_attr_t *attr) {
114     if (any_null(reorder_pd_iface, src_engine, src_md, dst_engine, dst_md))
115         return invalid_arguments;
116 
117     std::shared_ptr<primitive_desc_t> pd;
118     auto e = get_reorder_engine(src_engine, dst_engine);
119     CHECK(reorder_primitive_desc_create(
120             pd, e, src_md, src_engine, dst_md, dst_engine, attr));
121 
122     return safe_ptr_assign(*reorder_pd_iface,
123             new reorder_primitive_desc_iface_t(pd, e, src_engine, dst_engine));
124 }
125 
126 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
127