1 /******************************************************************************* 2 * Copyright 2018-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_GEMM_F32_GEMM_UTILS_F32_HPP 18 #define CPU_GEMM_F32_GEMM_UTILS_F32_HPP 19 20 #include <cstddef> 21 22 namespace dnnl { 23 namespace impl { 24 namespace cpu { 25 26 namespace gemm_utils { 27 template <typename T, bool isTransA, bool isTransB> 28 struct gemm_traits {}; 29 30 template <bool isTransA, bool isTransB> 31 struct gemm_traits<double, isTransA, isTransB> { 32 static constexpr dim_t m = 8; 33 static constexpr dim_t n = 6; 34 static constexpr dim_t BM = 4032; 35 static constexpr dim_t BN = isTransA ? 96 : 192; 36 static constexpr dim_t BK = isTransB ? 96 : 512; 37 }; 38 39 template <bool isTransA, bool isTransB> 40 struct gemm_traits<float, isTransA, isTransB> { 41 static constexpr dim_t m = 16; 42 static constexpr dim_t n = 6; 43 static constexpr dim_t BM = 4032; 44 static constexpr dim_t BN = isTransA ? 96 : 48; 45 static constexpr dim_t BK = isTransB ? 96 : 256; 46 }; 47 48 template <typename T> 49 using unroll_factor = gemm_traits<T, false, false>; 50 51 template <typename data_t> 52 void sum_two_matrices(dim_t m, dim_t n, data_t *__restrict p_src, dim_t ld_src, 53 data_t *__restrict p_dst, dim_t ld_dst); 54 55 void calc_nthr_nocopy_avx512_common(dim_t m, dim_t n, dim_t k, int nthrs, 56 int *nthrs_m, int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, 57 dim_t *BK); 58 59 void calc_nthr_nocopy_avx(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m, 60 int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK); 61 62 void partition_unit_diff( 63 int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block); 64 }; // namespace gemm_utils 65 66 } // namespace cpu 67 } // namespace impl 68 } // namespace dnnl 69 #endif // CPU_GEMM_F32_GEMM_UTILS_F32_HPP 70