1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Nadathur Satish (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_SPMDM_H
12 #define LIBXSMM_SPMDM_H
13 
14 #include "libxsmm_typedefs.h"
15 
16 
/**
 * Element type selector for SpMDM (stored in libxsmm_spmdm_handle::datatype).
 * The values below are the same ones the compiler would assign implicitly.
 */
typedef enum libxsmm_spmdm_datatype {
  /* 32-bit float operands (see the *_fp32_thread entry points). */
  LIBXSMM_SPMDM_DATATYPE_F32 = 0,
  /* libxsmm_bfloat16 operands (see the *_bfloat16_thread entry points). */
  LIBXSMM_SPMDM_DATATYPE_BFLOAT16 = 1
} libxsmm_spmdm_datatype;
21 
22 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_spmdm_handle {
23   /* The following are the matrix multiply dimensions: A (sparse): m X k, B (dense): k X n, Output C (dense): m X n */
24   int m;
25   int n;
26   int k;
27   /* The block sizes for A, B and C. */
28   /* Here we fix A to be divided into 128 X 128 blocks, B/C to be 128 X 48 for HSW/BDW and 128 X 96 for SKX */
29   int bm;
30   int bn;
31   int bk;
32   /* The number of blocks for the m, n and k dimensions */
33   int mb;
34   int nb;
35   int kb;
36   libxsmm_spmdm_datatype datatype;
37   char* base_ptr_scratch_A;
38   char* base_ptr_scratch_B_scratch_C;
39   int memory_for_scratch_per_thread;
40 } libxsmm_spmdm_handle;
41 
42 /**
43  * This stores a single sparse splice (or block) of sparse matrix A using a CSR representation (rowidx, colidx, and values
44  * Each splice corresponds to a bm X bk region of A, and stores local indexes
45  */
46 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_CSR_sparseslice {
47   /* Since bm and bk are assumed to be <=256, a 16-bit integer is enough to store the local rowidx, colidx */
48   uint16_t* rowidx;
49   uint16_t* colidx;
50   float*    values;
51 } libxsmm_CSR_sparseslice;
52 
53 
54 LIBXSMM_API void libxsmm_spmdm_init(
55   int M, int N, int K,
56   int max_threads,
57   libxsmm_spmdm_handle* handle,
58   libxsmm_CSR_sparseslice** libxsmm_output_csr);
59 
60 LIBXSMM_API void libxsmm_spmdm_destroy(
61   libxsmm_spmdm_handle* handle);
62 
63 LIBXSMM_API int libxsmm_spmdm_get_num_createSparseSlice_blocks(
64   const libxsmm_spmdm_handle* handle);
65 
66 LIBXSMM_API int libxsmm_spmdm_get_num_compute_blocks(
67   const libxsmm_spmdm_handle* handle);
68 
69 /** This converts a dense representation of the sparse matrix to 2D array of sparse slices. */
70 LIBXSMM_API void libxsmm_spmdm_createSparseSlice_fp32_thread(
71   const libxsmm_spmdm_handle* handle,
72   char transa,
73   const float* a,
74   libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
75   int block_id,
76   int tid, int nthreads);
77 
78 LIBXSMM_API void libxsmm_spmdm_createSparseSlice_bfloat16_thread(
79   const libxsmm_spmdm_handle* handle,
80   char transa,
81   const libxsmm_bfloat16* a,
82   libxsmm_CSR_sparseslice* libxsmm_output_csr_a,
83   int block_id,
84   int tid, int nthreads);
85 
86 /** NOTE: This code currently ignores alpha input to the matrix multiply */
87 LIBXSMM_API void libxsmm_spmdm_compute_fp32_thread(
88   const libxsmm_spmdm_handle* handle,
89   char transa,
90   char transb,
91   const float* alpha,
92   libxsmm_CSR_sparseslice* a_sparse,
93   const float* b,
94   char transc,
95   const float* beta,
96   float* c,
97   int block_id,
98   int tid, int nthreads);
99 
100 /** NOTE: This code currently ignores alpha input to the matrix multiply */
101 LIBXSMM_API void libxsmm_spmdm_compute_bfloat16_thread(
102   const libxsmm_spmdm_handle* handle,
103   char transa,
104   char transb,
105   const libxsmm_bfloat16* alpha,
106   libxsmm_CSR_sparseslice* a_sparse,
107   const libxsmm_bfloat16* b,
108   char transc,
109   const libxsmm_bfloat16* beta,
110   float* c,
111   int block_id,
112   int tid, int nthreads);
113 
114 #endif /*LIBXSMM_SPMDM_H*/
115 
116