1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Nadathur Satish (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_SPMDM_H 12 #define LIBXSMM_SPMDM_H 13 14 #include "libxsmm_typedefs.h" 15 16 17 typedef enum libxsmm_spmdm_datatype { 18 LIBXSMM_SPMDM_DATATYPE_F32, 19 LIBXSMM_SPMDM_DATATYPE_BFLOAT16 20 } libxsmm_spmdm_datatype; 21 22 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spmdm_handle { 23 /* The following are the matrix multiply dimensions: A (sparse): m X k, B (dense): k X n, Output C (dense): m X n */ 24 int m; 25 int n; 26 int k; 27 /* The block sizes for A, B and C. */ 28 /* Here we fix A to be divided into 128 X 128 blocks, B/C to be 128 X 48 for HSW/BDW and 128 X 96 for SKX */ 29 int bm; 30 int bn; 31 int bk; 32 /* The number of blocks for the m, n and k dimensions */ 33 int mb; 34 int nb; 35 int kb; 36 libxsmm_spmdm_datatype datatype; 37 char* base_ptr_scratch_A; 38 char* base_ptr_scratch_B_scratch_C; 39 int memory_for_scratch_per_thread; 40 } libxsmm_spmdm_handle; 41 42 /** 43 * This stores a single sparse splice (or block) of sparse matrix A using a CSR representation (rowidx, colidx, and values 44 * Each splice corresponds to a bm X bk region of A, and stores local indexes 45 */ 46 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_CSR_sparseslice { 47 /* Since bm and bk are assumed to be <=256, a 16-bit integer is enough to store the local rowidx, colidx */ 48 uint16_t* rowidx; 49 uint16_t* colidx; 50 float* values; 51 } libxsmm_CSR_sparseslice; 52 53 54 LIBXSMM_API void libxsmm_spmdm_init( 55 int M, int N, int K, 56 int max_threads, 57 libxsmm_spmdm_handle* handle, 58 libxsmm_CSR_sparseslice** libxsmm_output_csr); 59 60 LIBXSMM_API void libxsmm_spmdm_destroy( 61 libxsmm_spmdm_handle* handle); 62 63 LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks( 64 const libxsmm_spmdm_handle* handle); 65 66 LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks( 67 const libxsmm_spmdm_handle* handle); 68 69 /** This converts a dense representation of the sparse matrix to 2D array of sparse slices. */ 70 LIBXSMM_API void libxsmm_spmdm_createSparseSlice_fp32_thread( 71 const libxsmm_spmdm_handle* handle, 72 char transa, 73 const float* a, 74 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, 75 int block_id, 76 int tid, int nthreads); 77 78 LIBXSMM_API void libxsmm_spmdm_createSparseSlice_bfloat16_thread( 79 const libxsmm_spmdm_handle* handle, 80 char transa, 81 const libxsmm_bfloat16* a, 82 libxsmm_CSR_sparseslice* libxsmm_output_csr_a, 83 int block_id, 84 int tid, int nthreads); 85 86 /** NOTE: This code currently ignores alpha input to the matrix multiply */ 87 LIBXSMM_API void libxsmm_spmdm_compute_fp32_thread( 88 const libxsmm_spmdm_handle* handle, 89 char transa, 90 char transb, 91 const float* alpha, 92 libxsmm_CSR_sparseslice* a_sparse, 93 const float* b, 94 char transc, 95 const float* beta, 96 float* c, 97 int block_id, 98 int tid, int nthreads); 99 100 /** NOTE: This code currently ignores alpha input to the matrix multiply */ 101 LIBXSMM_API void libxsmm_spmdm_compute_bfloat16_thread( 102 const libxsmm_spmdm_handle* handle, 103 char transa, 104 char transb, 105 const libxsmm_bfloat16* alpha, 106 libxsmm_CSR_sparseslice* a_sparse, 107 const libxsmm_bfloat16* b, 108 char transc, 109 const libxsmm_bfloat16* beta, 110 float* c, 111 int block_id, 112 int tid, int nthreads); 113 114 #endif /*LIBXSMM_SPMDM_H*/ 115 116