1 /******************************************************************************* 2 * Copyright 2019-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef COMMON_GEMM_TYPES_HPP 18 #define COMMON_GEMM_TYPES_HPP 19 20 #include <assert.h> 21 22 #include "oneapi/dnnl/dnnl_types.h" 23 24 namespace dnnl { 25 namespace impl { 26 27 enum transpose_t { dnnl_notrans, dnnl_trans }; 28 29 namespace transpose { 30 const transpose_t notrans = dnnl_notrans; 31 const transpose_t trans = dnnl_trans; 32 } // namespace transpose 33 34 enum offsetc_t { dnnl_fixed, dnnl_column, dnnl_row }; 35 36 namespace offsetc { 37 const offsetc_t fixed = dnnl_fixed; 38 const offsetc_t column = dnnl_column; 39 const offsetc_t row = dnnl_row; 40 } // namespace offsetc 41 42 /** A descriptor for a matrix multiplication (gemm) operation */ 43 struct dnnl_gemm_desc_t { 44 /* To make the interface consistent, the descriptor represent the 45 * GEMM operation in row major */ 46 47 /** The kind of primitive. Used for self identifying the primitive 48 * descriptor. Must be #dnnl_gemm. */ 49 dnnl_primitive_kind_t primitive_kind; 50 dnnl_memory_desc_t a_desc; 51 dnnl_memory_desc_t b_desc; 52 dnnl_memory_desc_t c_desc; 53 dnnl_memory_desc_t bias_desc; 54 /** Type for accumulating A*B. */ 55 dnnl_data_type_t acc_type; 56 57 // These accessors are to be used by the GEMM implementation 58 // Because the GEMM implementation currently assumes column major 59 // These accessors return data in column major fashion 60 is_batcheddnnl::impl::dnnl_gemm_desc_t61 inline bool is_batched() const { return c_desc.ndims >= 3; } 62 63 // Simplified accessors that comply to GEMM API get_transdnnl::impl::dnnl_gemm_desc_t64 transpose_t get_trans(dnnl_memory_desc_t md) const { 65 return md.format_desc.blocking.strides[md.ndims - 1] != 1 66 ? transpose::trans 67 : transpose::notrans; 68 } transadnnl::impl::dnnl_gemm_desc_t69 transpose_t transa() const { return get_trans(b_desc); }; transbdnnl::impl::dnnl_gemm_desc_t70 transpose_t transb() const { return get_trans(a_desc); }; batchdnnl::impl::dnnl_gemm_desc_t71 dnnl_dim_t batch() const { 72 // if ndims < 3, it should return 1 73 int64_t batch = 1; 74 for (int i = 0; i < c_desc.ndims - 2; ++i) { 75 if (c_desc.dims[i] == DNNL_RUNTIME_DIM_VAL) 76 return DNNL_RUNTIME_DIM_VAL; 77 batch *= c_desc.dims[i]; 78 } 79 return batch; 80 } 81 82 /** Number of rows of C. */ mdnnl::impl::dnnl_gemm_desc_t83 dnnl_dim_t m() const { return c_desc.dims[c_desc.ndims - 1]; } 84 /** Number of columns of C. */ ndnnl::impl::dnnl_gemm_desc_t85 dnnl_dim_t n() const { return c_desc.dims[c_desc.ndims - 2]; } 86 /** Size of inner dimension shared between A and B. */ kdnnl::impl::dnnl_gemm_desc_t87 dnnl_dim_t k() const { return a_desc.dims[a_desc.ndims - 1]; } 88 89 /** Stride between 2 matrices A in a batch. */ stride_adnnl::impl::dnnl_gemm_desc_t90 dnnl_dim_t stride_a(int dim = 0) const { 91 return (dim >= b_desc.ndims - 2 || b_desc.dims[dim] == 1) 92 ? 0 93 : b_desc.format_desc.blocking.strides[dim]; 94 }; 95 /** Stride between 2 matrices B in a batch. */ stride_bdnnl::impl::dnnl_gemm_desc_t96 dnnl_dim_t stride_b(int dim = 0) const { 97 return (dim >= a_desc.ndims - 2 || a_desc.dims[dim] == 1) 98 ? 0 99 : a_desc.format_desc.blocking.strides[dim]; 100 }; 101 /** Stride between 2 matrices C in a batch. */ stride_cdnnl::impl::dnnl_gemm_desc_t102 dnnl_dim_t stride_c(int dim = 0) const { 103 return (dim >= c_desc.ndims - 2) 104 ? 0 105 : c_desc.format_desc.blocking.strides[dim]; 106 }; 107 108 // This assumes that one of the dimensions has strides 1 get_lddnnl::impl::dnnl_gemm_desc_t109 dnnl_dim_t get_ld(dnnl_memory_desc_t md) const { 110 auto strides = md.format_desc.blocking.strides; 111 assert(strides[md.ndims - 1] == 1 || strides[md.ndims - 2] == 1); 112 return strides[md.ndims - 1] != 1 ? strides[md.ndims - 1] 113 : strides[md.ndims - 2]; 114 } 115 /** Leading dimension of A. */ ldadnnl::impl::dnnl_gemm_desc_t116 dnnl_dim_t lda() const { return get_ld(b_desc); } 117 /** Leading dimension of B. */ ldbdnnl::impl::dnnl_gemm_desc_t118 dnnl_dim_t ldb() const { return get_ld(a_desc); } 119 /** Leading dimension of C. */ ldcdnnl::impl::dnnl_gemm_desc_t120 dnnl_dim_t ldc() const { return get_ld(c_desc); } 121 122 /** Type of matrix A. */ a_typednnl::impl::dnnl_gemm_desc_t123 dnnl_data_type_t a_type() const { return b_desc.data_type; } 124 /** Type of matrix B. */ b_typednnl::impl::dnnl_gemm_desc_t125 dnnl_data_type_t b_type() const { return a_desc.data_type; } 126 /** Type of matrix C. */ c_typednnl::impl::dnnl_gemm_desc_t127 dnnl_data_type_t c_type() const { return c_desc.data_type; } 128 /** Type of bias. */ bias_typednnl::impl::dnnl_gemm_desc_t129 dnnl_data_type_t bias_type() const { return bias_desc.data_type; } 130 /** Type of bias. */ bias_maskdnnl::impl::dnnl_gemm_desc_t131 int bias_mask() const { 132 assert(bias_desc.ndims <= 3); 133 int mask = 0; 134 // TODO: update the mask for batched dimension if we start 135 // supporting more batch dimensions 136 if (is_batched()) mask |= (bias_desc.dims[0] > 1) ? 1 << 0 : 0; 137 138 // because the bias mask is in row major, we have to convert 139 // to col major here by swapping two last dimensions 140 int m_idx = is_batched(); 141 mask |= (bias_desc.dims[m_idx] > 1) ? 1 << (bias_desc.ndims - m_idx) 142 : 0; 143 mask |= (bias_desc.dims[m_idx + 1] > 1) 144 ? 1 << (bias_desc.ndims - (m_idx + 1)) 145 : 0; 146 return mask; 147 } 148 }; 149 150 } // namespace impl 151 } // namespace dnnl 152 153 #endif // COMMON_GEMM_TYPES_HPP 154