1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 * Copyright 2020 Arm Ltd. and affiliates
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #ifndef CPU_CPU_ENGINE_HPP
19 #define CPU_CPU_ENGINE_HPP
20 
21 #include <assert.h>
22 
23 #include "oneapi/dnnl/dnnl.h"
24 
25 #include "common/c_types_map.hpp"
26 #include "common/engine.hpp"
27 #include "common/engine_id.hpp"
28 #include "common/impl_list_item.hpp"
29 
30 #include "cpu/platform.hpp"
31 
// CPU_INSTANCE(impl) expands to a single impl_list_item_t constructed from
// impl::pd_t through impl_list_item_t::type_deduction_helper_t, followed by a
// trailing comma, which suits using consecutive CPU_INSTANCE(...) lines as the
// elements of a brace-enclosed impl_list_item_t array initializer.
// The parameter is variadic (`...`) so that a template-id whose template
// arguments contain commas can be passed without extra parentheses.
#define CPU_INSTANCE(...) \
    impl_list_item_t( \
            impl_list_item_t::type_deduction_helper_t<__VA_ARGS__::pd_t>()),
// Target-specific variants: CPU_INSTANCE wrapped in the DNNL_*_ONLY guards.
// Those guards are defined outside this file (presumably cpu/platform.hpp);
// presumably they keep the entry only on the matching target/library build —
// confirm against their definitions.
#define CPU_INSTANCE_X64(...) DNNL_X64_ONLY(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AARCH64(...) DNNL_AARCH64_ONLY(CPU_INSTANCE(__VA_ARGS__))
#define CPU_INSTANCE_AARCH64_ACL(...) \
    DNNL_AARCH64_ACL_ONLY(CPU_INSTANCE(__VA_ARGS__))
39 
40 namespace dnnl {
41 namespace impl {
42 namespace cpu {
43 
44 #define DECLARE_IMPL_LIST(kind) \
45     const impl_list_item_t *get_##kind##_impl_list(const kind##_desc_t *desc);
46 
47 DECLARE_IMPL_LIST(batch_normalization);
48 DECLARE_IMPL_LIST(binary);
49 DECLARE_IMPL_LIST(convolution);
50 DECLARE_IMPL_LIST(deconvolution);
51 DECLARE_IMPL_LIST(eltwise);
52 DECLARE_IMPL_LIST(inner_product);
53 DECLARE_IMPL_LIST(layer_normalization);
54 DECLARE_IMPL_LIST(lrn);
55 DECLARE_IMPL_LIST(logsoftmax);
56 DECLARE_IMPL_LIST(matmul);
57 DECLARE_IMPL_LIST(pooling_v2);
58 DECLARE_IMPL_LIST(prelu);
59 DECLARE_IMPL_LIST(reduction);
60 DECLARE_IMPL_LIST(resampling);
61 DECLARE_IMPL_LIST(rnn);
62 DECLARE_IMPL_LIST(shuffle);
63 DECLARE_IMPL_LIST(softmax);
64 
65 #undef DECLARE_IMPL_LIST
66 
67 class cpu_engine_impl_list_t {
68 public:
69     static const impl_list_item_t *get_concat_implementation_list();
70     static const impl_list_item_t *get_reorder_implementation_list(
71             const memory_desc_t *src_md, const memory_desc_t *dst_md);
72     static const impl_list_item_t *get_sum_implementation_list();
73 
get_implementation_list(const op_desc_t * desc)74     static const impl_list_item_t *get_implementation_list(
75             const op_desc_t *desc) {
76         static const impl_list_item_t empty_list[] = {nullptr};
77 
78 // clang-format off
79 #define CASE(kind) \
80     case primitive_kind::kind: \
81         return get_##kind##_impl_list((const kind##_desc_t *)desc);
82         switch (desc->kind) {
83             CASE(batch_normalization);
84             CASE(binary);
85             CASE(convolution);
86             CASE(deconvolution);
87             CASE(eltwise);
88             CASE(inner_product);
89             CASE(layer_normalization);
90             CASE(lrn);
91             CASE(logsoftmax);
92             CASE(matmul);
93             case primitive_kind::pooling:
94             CASE(pooling_v2);
95             CASE(prelu);
96             CASE(reduction);
97             CASE(resampling);
98             CASE(rnn);
99             CASE(shuffle);
100             CASE(softmax);
101             default: assert(!"unknown primitive kind"); return empty_list;
102         }
103 #undef CASE
104     }
105     // clang-format on
106 };
107 
108 class cpu_engine_t : public engine_t {
109 public:
cpu_engine_t()110     cpu_engine_t() : engine_t(engine_kind::cpu, get_cpu_native_runtime(), 0) {}
111 
112     /* implementation part */
113 
114     status_t create_memory_storage(memory_storage_t **storage, unsigned flags,
115             size_t size, void *handle) override;
116 
117     status_t create_stream(stream_t **stream, unsigned flags) override;
118 
119 #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
120     status_t create_stream(stream_t **stream,
121             dnnl::threadpool_interop::threadpool_iface *threadpool) override;
122 #endif
123 
get_concat_implementation_list() const124     const impl_list_item_t *get_concat_implementation_list() const override {
125         return cpu_engine_impl_list_t::get_concat_implementation_list();
126     }
127 
get_reorder_implementation_list(const memory_desc_t * src_md,const memory_desc_t * dst_md) const128     const impl_list_item_t *get_reorder_implementation_list(
129             const memory_desc_t *src_md,
130             const memory_desc_t *dst_md) const override {
131         return cpu_engine_impl_list_t::get_reorder_implementation_list(
132                 src_md, dst_md);
133     }
get_sum_implementation_list() const134     const impl_list_item_t *get_sum_implementation_list() const override {
135         return cpu_engine_impl_list_t::get_sum_implementation_list();
136     }
get_implementation_list(const op_desc_t * desc) const137     const impl_list_item_t *get_implementation_list(
138             const op_desc_t *desc) const override {
139         return cpu_engine_impl_list_t::get_implementation_list(desc);
140     }
141 
device_id() const142     device_id_t device_id() const override { return std::make_tuple(0, 0, 0); }
143 
144 #ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE
engine_id() const145     engine_id_t engine_id() const override {
146         // Non-sycl CPU engine doesn't have device and context.
147         return {};
148     }
149 
150 protected:
151     ~cpu_engine_t() override = default;
152 #endif
153 };
154 
155 class cpu_engine_factory_t : public engine_factory_t {
156 public:
count() const157     size_t count() const override { return 1; }
engine_create(engine_t ** engine,size_t index) const158     status_t engine_create(engine_t **engine, size_t index) const override {
159         assert(index == 0);
160         *engine = new cpu_engine_t();
161         return status::success;
162     };
163 };
164 
165 } // namespace cpu
166 } // namespace impl
167 } // namespace dnnl
168 
169 #endif
170 
171 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
172