1 /*******************************************************************************
2 * Copyright 2018-2020 Intel Corporation
3 * Copyright 2020 FUJITSU LIMITED
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *******************************************************************************/
17 
18 #ifndef CPU_AARCH64_JIT_UNI_REORDER_HPP
19 #define CPU_AARCH64_JIT_UNI_REORDER_HPP
20 
21 #include <assert.h>
22 
23 #include "common/c_types_map.hpp"
24 #include "common/type_helpers.hpp"
25 
26 #include "cpu/reorder/cpu_reorder_pd.hpp"
27 
28 namespace dnnl {
29 namespace impl {
30 namespace cpu {
31 namespace aarch64 {
32 
33 namespace tr {
34 
35 constexpr int max_ndims = DNNL_MAX_NDIMS;
36 
/** A single dimension of a transposition problem: its extent together with
 * the strides used to step along it in the input, the output and the scales.
 * NOTE(review): stride units (elements vs. bytes) are set by prb_init —
 * confirm there before relying on them. */
struct node_t {
    size_t n; ///< extent of this dimension
    ptrdiff_t is; ///< input stride
    ptrdiff_t os; ///< output stride
    ptrdiff_t ss; ///< scale stride
};
43 
/** Kind of scaling applied by the reorder: no scales, one common scale, or
 * many scales. */
enum class scale_type_t : int {
    NONE,
    COMMON,
    MANY,
};
45 
46 struct prb_t {
47     data_type_t itype;
48     data_type_t otype;
49     int ndims;
50     node_t nodes[max_ndims];
51     ptrdiff_t ioff;
52     ptrdiff_t ooff;
53     scale_type_t scale_type;
54     float beta;
55 };
56 
57 status_t prb_init(prb_t &prb, const memory_desc_t &imd,
58         const memory_desc_t &omd, const primitive_attr_t *attr);
59 
60 /** sorts the problem nodes so that output strides come in ascending order */
61 void prb_normalize(prb_t &p);
62 
63 /** folds nodes together if possible */
64 void prb_simplify(prb_t &p);
65 
66 /** splits the node dim into two of sizes n1 and n / n1
67  * @warning n must be multiple of n1 */
68 void prb_node_split(prb_t &p, int dim, size_t n1);
69 
70 /** swaps d0 and d1 nodes */
71 void prb_node_swap(prb_t &p, int d0, int d1);
72 
73 /** moves node d0 to the d1 position.
74  * nodes (d0, d1] are shifted to the left if d0 < d1 or
75  * to the right if d0 > d1 */
76 void prb_node_move(prb_t &p, int d0, int d1);
77 
78 /** dumps the problem to stdout */
79 void prb_dump(const prb_t &p);
80 
/** Arguments handed to one kernel invocation (kernel_t::operator()).
 * NOTE(review): generated code presumably reads these members by offset —
 * keep their order unchanged. */
struct call_param_t {
    const void *in; ///< source data
    void *out; ///< destination data
    const float *scale; ///< scale factor(s); relevance depends on scale_type
};
86 
87 struct kernel_t {
88     struct desc_t {
89         int id;
90         prb_t prb;
91     };
92 
kernel_tdnnl::impl::cpu::aarch64::tr::kernel_t93     kernel_t(const desc_t &desc) : desc_(desc) {}
94     virtual void operator()(const call_param_t *c) const = 0;
95     virtual status_t create_kernel() = 0;
~kernel_tdnnl::impl::cpu::aarch64::tr::kernel_t96     virtual ~kernel_t() {}
97 
98     /** inits kernel descriptor:
99      *      desc            -- kernel descriptor (output)
100      *      prb             -- transposition problem (input)
101      *      ndims_ker_max   -- limit the maximum number of dimensions kernel
102      *                         will process (optional, 0 -- no limitation) */
103     static status_t desc_init(
104             desc_t &desc, const prb_t &prb, int ndims_ker_max = 0);
105 
106     /** creates kernel for the problem described in desc */
107     static kernel_t *create(const desc_t &desc);
108 
109 protected:
110     const desc_t desc_;
111     const prb_t &prb_ = desc_.prb;
112 };
113 
114 /* TODO: add trans_t class */
115 
116 } // namespace tr
117 
118 struct jit_uni_reorder_t : public primitive_t {
119     using primitive_t::primitive_t;
120     struct pd_t : public cpu_reorder_pd_t {
121         using cpu_reorder_pd_t::cpu_reorder_pd_t;
122 
123         DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
124 
125         tr::prb_t prb_;
126         tr::kernel_t::desc_t ker_desc_;
127         int nthr_;
128 
129     private:
130         static status_t create(reorder_pd_t **reorder_pd, engine_t *engine,
131                 const primitive_attr_t *attr, engine_t *src_engine,
132                 const memory_desc_t *src_md, engine_t *dst_engine,
133                 const memory_desc_t *dst_md);
134 
135         friend dnnl::impl::impl_list_item_t;
136     };
137 
138     status_t init(engine_t *engine) override;
139     status_t execute(const exec_ctx_t &ctx) const override;
140 
141     enum { ndims_driver_max = 4 };
142 
143 private:
144     void omp_driver_0d(
145             int off, const char *in, char *out, const float *scale) const;
146     void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
147             const float *scale) const;
148     void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
149             const float *scale) const;
150     void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
151             const float *scale) const;
152     void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
153             const float *scale) const;
154 
155     void omp_driver(const char *in, char *out, const float *scale) const;
156 
pddnnl::impl::cpu::aarch64::jit_uni_reorder_t157     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
158     std::unique_ptr<tr::kernel_t> kernel_;
159 };
160 
161 } // namespace aarch64
162 } // namespace cpu
163 } // namespace impl
164 } // namespace dnnl
165 
166 #endif
167