1 /******************************************************************************* 2 * Copyright 2018-2020 Intel Corporation 3 * Copyright 2020 FUJITSU LIMITED 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 *******************************************************************************/ 17 18 #ifndef CPU_AARCH64_JIT_UNI_REORDER_HPP 19 #define CPU_AARCH64_JIT_UNI_REORDER_HPP 20 21 #include <assert.h> 22 23 #include "common/c_types_map.hpp" 24 #include "common/type_helpers.hpp" 25 26 #include "cpu/reorder/cpu_reorder_pd.hpp" 27 28 namespace dnnl { 29 namespace impl { 30 namespace cpu { 31 namespace aarch64 { 32 33 namespace tr { 34 35 constexpr int max_ndims = DNNL_MAX_NDIMS; 36 37 struct node_t { 38 size_t n; 39 ptrdiff_t is; // input stride 40 ptrdiff_t os; // output stride 41 ptrdiff_t ss; // scale stride 42 }; 43 44 enum class scale_type_t { NONE, COMMON, MANY }; 45 46 struct prb_t { 47 data_type_t itype; 48 data_type_t otype; 49 int ndims; 50 node_t nodes[max_ndims]; 51 ptrdiff_t ioff; 52 ptrdiff_t ooff; 53 scale_type_t scale_type; 54 float beta; 55 }; 56 57 status_t prb_init(prb_t &prb, const memory_desc_t &imd, 58 const memory_desc_t &omd, const primitive_attr_t *attr); 59 60 /** sorts the problem nodes so that output strides come in ascending order */ 61 void prb_normalize(prb_t &p); 62 63 /** folds nodes together if possible */ 64 void prb_simplify(prb_t &p); 65 66 /** splits the node dim into two of sizes n1 and n / n1 67 * @warning n must be multiple of n1 */ 68 void prb_node_split(prb_t &p, int dim, size_t n1); 69 70 /** swaps d0 and d1 nodes */ 71 void prb_node_swap(prb_t &p, int d0, int d1); 72 73 /** moves node d0 to the d1 position. 74 * nodes (d0, d1] are shifted to the left if d0 < d1 or 75 * to the right if d0 > d1 */ 76 void prb_node_move(prb_t &p, int d0, int d1); 77 78 /** dumps the problem to stdout */ 79 void prb_dump(const prb_t &p); 80 81 struct call_param_t { 82 const void *in; 83 void *out; 84 const float *scale; 85 }; 86 87 struct kernel_t { 88 struct desc_t { 89 int id; 90 prb_t prb; 91 }; 92 kernel_tdnnl::impl::cpu::aarch64::tr::kernel_t93 kernel_t(const desc_t &desc) : desc_(desc) {} 94 virtual void operator()(const call_param_t *c) const = 0; 95 virtual status_t create_kernel() = 0; ~kernel_tdnnl::impl::cpu::aarch64::tr::kernel_t96 virtual ~kernel_t() {} 97 98 /** inits kernel descriptor: 99 * desc -- kernel descriptor (output) 100 * prb -- transposition problem (input) 101 * ndims_ker_max -- limit the maximum number of dimensions kernel 102 * will process (optional, 0 -- no limitation) */ 103 static status_t desc_init( 104 desc_t &desc, const prb_t &prb, int ndims_ker_max = 0); 105 106 /** creates kernel for the problem described in desc */ 107 static kernel_t *create(const desc_t &desc); 108 109 protected: 110 const desc_t desc_; 111 const prb_t &prb_ = desc_.prb; 112 }; 113 114 /* TODO: add trans_t class */ 115 116 } // namespace tr 117 118 struct jit_uni_reorder_t : public primitive_t { 119 using primitive_t::primitive_t; 120 struct pd_t : public cpu_reorder_pd_t { 121 using cpu_reorder_pd_t::cpu_reorder_pd_t; 122 123 DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); 124 125 tr::prb_t prb_; 126 tr::kernel_t::desc_t ker_desc_; 127 int nthr_; 128 129 private: 130 static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, 131 const primitive_attr_t *attr, engine_t *src_engine, 132 const memory_desc_t *src_md, engine_t *dst_engine, 133 const memory_desc_t *dst_md); 134 135 friend dnnl::impl::impl_list_item_t; 136 }; 137 138 status_t init(engine_t *engine) override; 139 status_t execute(const exec_ctx_t &ctx) const override; 140 141 enum { ndims_driver_max = 4 }; 142 143 private: 144 void omp_driver_0d( 145 int off, const char *in, char *out, const float *scale) const; 146 void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, 147 const float *scale) const; 148 void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, 149 const float *scale) const; 150 void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, 151 const float *scale) const; 152 void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, 153 const float *scale) const; 154 155 void omp_driver(const char *in, char *out, const float *scale) const; 156 pddnnl::impl::cpu::aarch64::jit_uni_reorder_t157 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 158 std::unique_ptr<tr::kernel_t> kernel_; 159 }; 160 161 } // namespace aarch64 162 } // namespace cpu 163 } // namespace impl 164 } // namespace dnnl 165 166 #endif 167