/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Nadathur Satish (Intel Corp.)
******************************************************************************/

const int m_blocks = handle->mb;
/*const int n_blocks = handle->nb;*/
const int k_blocks = handle->kb;
const int m_block_size = handle->bm;
const int n_block_size = handle->bn;
const int k_block_size = handle->bk;
int mb = block_id / handle->nb;
int nb = block_id % handle->nb;


#define LIBXSMM_SPMDM_COMPUTE_NREGS (6)
int m_overall_start = mb*m_block_size;
int m_overall_end = (mb + 1)*m_block_size;
int num_m;
int num_m_aligned;

int n_overall_start = nb*n_block_size;
int n_overall_end = (nb + 1)*n_block_size;
int num_n;
int m, n, k, kb;
int last_block_n, num_full_regs, last_n_start;

int k_overall_start, k_overall_end, num_k;

/* per-thread scratch: an m_block_size x n_block_size fp32 C panel, followed by a k_block_size x n_block_size fp32 B panel */
float *const scratch_C = (float *)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread);
float *const scratch_B = (float *)(handle->base_ptr_scratch_B_scratch_C + (size_t)tid*handle->memory_for_scratch_per_thread + (size_t)m_block_size*n_block_size*sizeof(float));
#if 0
float *const scratch_C = (float *)(handle->spmdm_scratch_C + tid*m_block_size*n_block_size*sizeof(float));
float *const scratch_B = (float *)(handle->spmdm_scratch_B + tid*k_block_size*n_block_size*sizeof(float));
#endif

SIMDTYPE_FP32 sum[2*LIBXSMM_SPMDM_COMPUTE_NREGS];
float* LIBXSMM_RESTRICT ptr_result;
#if SIMD_WIDTH_FP32 > 1
SIMDTYPE_INT32 vzero = _MM_SETZERO_INT32();
#endif

LIBXSMM_UNUSED(nthreads);
LIBXSMM_UNUSED(transa);
LIBXSMM_UNUSED(alpha);
LIBXSMM_UNUSED(beta);
LIBXSMM_UNUSED(tid);

/* really is twice this: sum[] holds 2*LIBXSMM_SPMDM_COMPUTE_NREGS accumulators (two C rows in flight) */
assert(n_block_size == LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32);

if (m_overall_end > handle->m) m_overall_end = handle->m;
num_m = (m_overall_end - m_overall_start);
num_m_aligned = (num_m / 2) * 2;

if (n_overall_end > handle->n) n_overall_end = handle->n;
num_n = (n_overall_end - n_overall_start);
last_block_n = (num_n != n_block_size);
/* full SIMD registers per row, rounded down to an even count (registers are consumed in pairs) */
num_full_regs = (num_n / SIMD_WIDTH_FP32);
if ((num_full_regs > 0) && (num_full_regs%2)) num_full_regs--;
last_n_start = num_full_regs*SIMD_WIDTH_FP32;
/* Copy in c matrix to buffer */
ptr_result = c + (size_t)m_overall_start*handle->n + n_overall_start;
if (LIBXSMM_FEQ(0.f, *beta)) {
  /* beta == 0: zero-initialize the C scratch panel */
  if (!last_block_n) {
    for (m = 0; m < num_m; m++) {
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
    }
  } else {
    for (m = 0; m < num_m; m++) {
      for (n = 0; n < num_full_regs; n += 2) {
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n)*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_SETZERO_FP32());
      }
      for (n = last_n_start; n < num_n; n++) {
        scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = 0;
      }
    }
  }
}
else if (LIBXSMM_FEQ(1.f, *beta)) {
  /* beta == 1: copy the C block into scratch unchanged */
  if ('T' == transc || 't' == transc) {
    int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int m2;

    ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start;

    for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) {
      for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) {
        TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle->m + m, handle->m, scratch_C + (size_t)m*n_block_size + n, n_block_size);
      }
      /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */
      for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) {
        for (n = num_n_simd; n < num_n; n++) {
          scratch_C[m2*n_block_size + n] = ptr_result[n*handle->m + m2];
        }
      }
    }
    /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */
    for (m = num_m_simd; m < num_m; m++) {
      for (n = 0; n < num_n; n++) {
        scratch_C[m*n_block_size + n] = ptr_result[n*handle->m + m];
      }
    }
  }
  else {
    if (!last_block_n) {
      for (m = 0; m < num_m; m++) {
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32));
      }
    }
    else {
      for (m = 0; m < num_m; m++) {
        for (n = 0; n < num_full_regs; n += 2) {
          _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32));
          _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32));
        }
        for (n = last_n_start; n < num_n; n++) {
          scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32+n] = ptr_result[m*handle->n+n];
        }
      }
    }
  }
}
else {
  SIMDTYPE_FP32 beta_v = _MM_SET1_FP32(*beta);
  /* general beta: scale the incoming C block by beta while copying it into scratch */
  if ('T' == transc || 't' == transc) {
    int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int m2;

    ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start;

    for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) {
      for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) {
        TRANSPOSE_SIMD_WIDTH_KERNEL(ptr_result + (size_t)n*handle->m + m, handle->m, scratch_C + (size_t)m*n_block_size + n, n_block_size);
        /* scale the freshly transposed tile by beta, one scratch row at a time (unrolled for eight rows) */
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*1)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*2)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*3)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*4)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*5)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*6)));
        _MM_STORE_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(scratch_C + (size_t)m*n_block_size + n + (size_t)n_block_size*7)));
      }
      /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */
      for (m2 = m; m2 < m + SIMD_WIDTH_FP32; m2++) {
        for (n = num_n_simd; n < num_n; n++) {
          scratch_C[m2*n_block_size + n] = (*beta)*ptr_result[n*handle->m + m2];
        }
      }
    }
    /* Transpose a (num_m - num_m_simd) * num_n block of output space - input is of size num_n * (num_m - num_m_simd) */
    for (m = num_m_simd; m < num_m; m++) {
      for (n = 0; n < num_n; n++) {
        scratch_C[m*n_block_size + n] = (*beta)*ptr_result[n*handle->m + m];
      }
    }

  }
  else {
    if (!last_block_n) {
      for (m = 0; m < num_m; m++) {
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32)));
      }
    }
    else {
      for (m = 0; m < num_m; m++) {
        for (n = 0; n < num_full_regs; n += 2) {
          _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32)));
          _MM_STORE_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_MUL_FP32(beta_v, _MM_LOADU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32)));
        }
        for (n = last_n_start; n < num_n; n++) {
          scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = (*beta)*ptr_result[m*handle->n + n];
        }
      }
    }
  }
}
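/* A minimal scalar sketch of what the TRANSPOSE_SIMD_WIDTH_KERNEL macro is
 * assumed to do in the copy-in above and the copy-out below: transpose one
 * SIMD_WIDTH_FP32 x SIMD_WIDTH_FP32 fp32 tile, where src/ld_src and dst/ld_dst
 * stand for the macro's four arguments. The actual SIMD implementation is
 * provided by the including translation unit; the _BFLOAT16 variant used for
 * the B panel additionally widens each bfloat16 element to fp32 (cf. the
 * scalar union-shift tails below). Compiled out, for exposition only.
 */
#if 0
{
  int i, j;
  for (i = 0; i < SIMD_WIDTH_FP32; i++) {
    for (j = 0; j < SIMD_WIDTH_FP32; j++) {
      dst[(size_t)j*ld_dst + i] = src[(size_t)i*ld_src + j]; /* dst(j, i) = src(i, j) */
    }
  }
}
#endif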
for (kb = 0; kb < k_blocks; kb++) {
  const uint16_t* LIBXSMM_RESTRICT ptr_dense;
  float * LIBXSMM_RESTRICT scratch_C_base;
  const float * LIBXSMM_RESTRICT scratch_B_base;
  int block_A = kb * m_blocks + mb;
  libxsmm_CSR_sparseslice slice = a_sparse[block_A];
  int m_local = 0;

  k_overall_start = kb*k_block_size;
  k_overall_end = (kb+1)*k_block_size;
  num_k = (k_overall_end - k_overall_start);

  /* Copy in b matrix, expanding bfloat16 to fp32 into scratch_B */
  if ('T' == transb || 't' == transb) {
    int num_k_simd = num_k / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
    int k2;

    ptr_dense = b + (size_t)n_overall_start*handle->k + k_overall_start;

    for (k = 0; k < num_k_simd; k += SIMD_WIDTH_FP32) {
      for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) {
        TRANSPOSE_SIMD_WIDTH_KERNEL_BFLOAT16(ptr_dense + (size_t)n*handle->k + k, handle->k, scratch_B + (size_t)k*n_block_size + n, n_block_size);
      }
      /* Transpose a SIMD_WIDTH_FP32 * (num_n - num_n_simd) block of output space - input is of size (num_n - num_n_simd) * SIMD_WIDTH_FP32 */
      for (k2 = k; k2 < k + SIMD_WIDTH_FP32; k2++) {
        for (n = num_n_simd; n < num_n; n++) {
          /* expand one bfloat16 scalar to fp32 by shifting it into the high 16 bits */
          uint16_t restmp = ptr_dense[n*handle->k + k2];
          union { int i; float f; } res;
          res.i = restmp;
          res.i <<= 16;
          scratch_B[k2*n_block_size + n] = res.f;
        }
      }
    }
    /* Transpose a (num_k - num_k_simd) * num_n block of output space - input is of size num_n * (num_k - num_k_simd) */
    for (k = num_k_simd; k < num_k; k++) {
      for (n = 0; n < num_n; n++) {
        uint16_t restmp = ptr_dense[n*handle->k + k];
        union { int i; float f; } res;
        res.i = restmp;
        res.i <<= 16;
        scratch_B[k*n_block_size + n] = res.f;
      }
    }
  } else {
    ptr_dense = b + (size_t)k_overall_start*handle->n + n_overall_start;
    if (!last_block_n) {
      for (k = 0; k < num_k; k++) {
        SIMDTYPE_INT32 vload_0 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + 2*0*SIMD_WIDTH_FP32));
        SIMDTYPE_INT32 vload_1, vload_2;
        SIMDTYPE_FP32 v1_0, v2_0;
        SIMDTYPE_FP32 v1_1, v2_1;
        SIMDTYPE_FP32 v1_2, v2_2;
        EXPAND_BFLOAT16(vload_0, v1_0, v2_0);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*0*SIMD_WIDTH_FP32, v1_0);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*0+1)*SIMD_WIDTH_FP32, v2_0);
        vload_1 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + 2*1*SIMD_WIDTH_FP32));
        EXPAND_BFLOAT16(vload_1, v1_1, v2_1);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*1*SIMD_WIDTH_FP32, v1_1);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*1+1)*SIMD_WIDTH_FP32, v2_1);
        vload_2 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + 2*2*SIMD_WIDTH_FP32));
        EXPAND_BFLOAT16(vload_2, v1_2, v2_2);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*2*SIMD_WIDTH_FP32, v1_2);
        _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + (2*2+1)*SIMD_WIDTH_FP32, v2_2);
      }
    } else {
      for (k = 0; k < num_k; k++) {
        for (n = 0; n < num_full_regs; n += 2) {
          SIMDTYPE_INT32 vload_0 = _MM_LOADU_INT32((const SIMDTYPE_INT32*)(ptr_dense + (size_t)k*handle->n + (size_t)n*SIMD_WIDTH_FP32));
          SIMDTYPE_FP32 v1_0, v2_0;
          EXPAND_BFLOAT16(vload_0, v1_0, v2_0);
          _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32, v1_0);
          _MM_STORE_FP32(scratch_B + (size_t)k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32, v2_0);
        }
        for (n = last_n_start; n < num_n; n++) {
          uint16_t restmp = ptr_dense[k*handle->n + n];
          union { int i; float f; } res;
          res.i = restmp;
          res.i <<= 16;
          {
            scratch_B[k*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n] = res.f;
          }
        }
      }
    }
  }

  /* rebase so the compute loops can index scratch_C with the absolute row m */
  scratch_C_base = scratch_C - (size_t)m_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
  scratch_B_base = scratch_B; /* - (size_t)k_overall_start*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32; */

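  /* Reference semantics of the accumulation below, as a scalar sketch
   * (compiled out): for every stored nonzero a(mm, kk) of the current A
   * block, the kk-th row of the expanded B panel is scaled and added into
   * the mm-th row of the C panel. The sketch ignores the two-rows-at-a-time
   * blocking and the A-block advance that the real loops perform.
   */
#if 0
  {
    int mm, jj, nn;
    for (mm = 0; mm < num_m; mm++) {
      for (jj = slice.rowidx[mm]; jj < slice.rowidx[mm+1]; jj++) {
        const float a_val = ((const float*)slice.values)[jj]; /* nonzero value a(mm, kk) */
        const int kk = slice.colidx[jj];                      /* column index within this k block */
        for (nn = 0; nn < num_n; nn++) {
          scratch_C[(size_t)mm*n_block_size + nn] += a_val*scratch_B[(size_t)kk*n_block_size + nn];
        }
      }
    }
  }
#endif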
  /* process two rows of the sparse A block per iteration, keeping two sets of accumulators in flight */
  for (m = m_overall_start; m < m_overall_start + num_m_aligned; m += 2, m_local += 2) {
    int start_j, end_j, end_j_2, num_j, num_j_2;
    const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base;
    const uint16_t *LIBXSMM_RESTRICT sp_c_ptr_base_2;
    const float *LIBXSMM_RESTRICT sp_v_ptr_base;
    const float *LIBXSMM_RESTRICT sp_v_ptr_base_2;
    float *const LIBXSMM_RESTRICT result_m_index = scratch_C_base + ((size_t)m) *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
    float *const LIBXSMM_RESTRICT result_m_index_2 = scratch_C_base + ((size_t)m+1)*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;

    /* the row index crossed into the next block of A: advance to that slice */
    if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; }

    start_j = slice.rowidx[m_local];
    end_j = slice.rowidx[m_local + 1];
    end_j_2 = slice.rowidx[m_local + 2];
    num_j = (end_j - start_j);
    num_j_2 = (end_j_2 - end_j);
    sp_c_ptr_base = slice.colidx + start_j;
    sp_c_ptr_base_2 = slice.colidx + end_j;
    sp_v_ptr_base = (float *)(slice.values) + start_j;
    sp_v_ptr_base_2 = (float *)(slice.values) + end_j;

    if (!last_block_n)
    {
      int64_t j = 0, j2 = 0;
      sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32);
      sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32);
      sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32);
      sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32);
      sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32);
      sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32);
      sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32);
      sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32);
      sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32);
      sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32);
      sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32);
      sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_LOAD_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32);
      for (; j < num_j && j2 < num_j_2; j++, j2++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j] *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]);
        sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]);
        sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]);
        sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]);
        sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]);
        sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]);
        sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]);
        sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      }
      for (; j < num_j; j++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]);
        sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]);
        sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]);
        sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]);
        sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]);
        sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]);
      }
      for (; j2 < num_j_2; j2++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]);
        sum[0 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 0*SIMD_WIDTH_FP32), sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 1*SIMD_WIDTH_FP32), sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[2 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 2*SIMD_WIDTH_FP32), sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[3 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 3*SIMD_WIDTH_FP32), sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[4 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 4*SIMD_WIDTH_FP32), sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        sum[5 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + 5*SIMD_WIDTH_FP32), sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      }
      _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]);
      _MM_STORE_FP32(result_m_index_2 + 0*SIMD_WIDTH_FP32, sum[0+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]);
      _MM_STORE_FP32(result_m_index_2 + 1*SIMD_WIDTH_FP32, sum[1+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]);
      _MM_STORE_FP32(result_m_index_2 + 2*SIMD_WIDTH_FP32, sum[2+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]);
      _MM_STORE_FP32(result_m_index_2 + 3*SIMD_WIDTH_FP32, sum[3+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]);
      _MM_STORE_FP32(result_m_index_2 + 4*SIMD_WIDTH_FP32, sum[4+LIBXSMM_SPMDM_COMPUTE_NREGS]);
      _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]);
      _MM_STORE_FP32(result_m_index_2 + 5*SIMD_WIDTH_FP32, sum[5+LIBXSMM_SPMDM_COMPUTE_NREGS]);
    }
    else {
      int64_t j = 0, j2 = 0;
      /* partial-width block: accumulate into zeroed registers and add them into scratch
       * at the end; the scalar tails below update the scratch rows directly */
      for (n = 0; n < num_full_regs; n += 2) {
        sum[n] = _MM_SETZERO_FP32();
        sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32();
        sum[n+1] = _MM_SETZERO_FP32();
        sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_SETZERO_FP32();
      }
      for (; j < num_j && j2 < num_j_2; j++, j2++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j] *LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]);
        for (n = 0; n < num_full_regs; n += 2) {
          sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + (size_t)n*SIMD_WIDTH_FP32), sum[n]);
          sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + (size_t)n*SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]);
          sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]);
          sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        }
        {
          float v_v_f = sp_v_ptr_base[j];
          float v_v_f_2 = sp_v_ptr_base_2[j2];
          for (n = last_n_start; n < num_n; n++) {
            result_m_index[n] += sp_col_dense_index[n]*v_v_f;
            result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2;
          }
        }
      }
      for (; j < num_j; j++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        for (n = 0; n < num_full_regs; n += 2) {
          sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]);
          sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]);
        }
        {
          float v_v_f = sp_v_ptr_base[j];
          for (n = last_n_start; n < num_n; n++) {
            result_m_index[n] += sp_col_dense_index[n]*v_v_f;
          }
        }
      }
      for (; j2 < num_j_2; j2++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index_2 = scratch_B_base + (size_t)sp_c_ptr_base_2[j2]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v_2 = _MM_SET1_FP32(sp_v_ptr_base_2[j2]);
        for (n = 0; n < num_full_regs; n += 2) {
          sum[n + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n) *SIMD_WIDTH_FP32), sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS]);
          sum[n+1 + LIBXSMM_SPMDM_COMPUTE_NREGS] = _MM_FMADD_FP32(v_v_2, _MM_LOAD_FP32(sp_col_dense_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS]);
        }
        {
          float v_v_f_2 = sp_v_ptr_base_2[j2];
          for (n = last_n_start; n < num_n; n++) {
            result_m_index_2[n] += sp_col_dense_index_2[n]*v_v_f_2;
          }
        }
      }
      for (n = 0; n < num_full_regs; n += 2) {
        _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + (size_t)n*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(result_m_index_2 + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + (size_t)n*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1+LIBXSMM_SPMDM_COMPUTE_NREGS], _MM_LOAD_FP32(result_m_index_2 + ((size_t)n+1)*SIMD_WIDTH_FP32)));
      }
    }
  }
  /* handle the last (odd) row of the block, if any */
  for (m = m_overall_start + num_m_aligned; m < m_overall_end; m++, m_local++) {
    int start_j, end_j, num_j;
    const uint16_t* LIBXSMM_RESTRICT sp_c_ptr_base;
    const float* LIBXSMM_RESTRICT sp_v_ptr_base;
    float* LIBXSMM_RESTRICT result_m_index;

    if (m_local >= m_block_size) { block_A++; slice = a_sparse[block_A]; m_local = 0; }

    start_j = slice.rowidx[m_local];
    end_j = slice.rowidx[m_local + 1];
    num_j = (end_j - start_j);
    sp_c_ptr_base = slice.colidx + start_j;
    sp_v_ptr_base = slice.values + start_j;
    result_m_index = scratch_C_base + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;

    if (!last_block_n) {
      int64_t j = 0;
      sum[0] = _MM_LOAD_FP32(result_m_index + 0*SIMD_WIDTH_FP32);
      sum[1] = _MM_LOAD_FP32(result_m_index + 1*SIMD_WIDTH_FP32);
      sum[2] = _MM_LOAD_FP32(result_m_index + 2*SIMD_WIDTH_FP32);
      sum[3] = _MM_LOAD_FP32(result_m_index + 3*SIMD_WIDTH_FP32);
      sum[4] = _MM_LOAD_FP32(result_m_index + 4*SIMD_WIDTH_FP32);
      sum[5] = _MM_LOAD_FP32(result_m_index + 5*SIMD_WIDTH_FP32);
      for (; j < num_j; j++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        sum[0] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 0*SIMD_WIDTH_FP32), sum[0]);
        sum[1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 1*SIMD_WIDTH_FP32), sum[1]);
        sum[2] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 2*SIMD_WIDTH_FP32), sum[2]);
        sum[3] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 3*SIMD_WIDTH_FP32), sum[3]);
        sum[4] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 4*SIMD_WIDTH_FP32), sum[4]);
        sum[5] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + 5*SIMD_WIDTH_FP32), sum[5]);
      }
      _MM_STORE_FP32(result_m_index + 0*SIMD_WIDTH_FP32, sum[0]);
      _MM_STORE_FP32(result_m_index + 1*SIMD_WIDTH_FP32, sum[1]);
      _MM_STORE_FP32(result_m_index + 2*SIMD_WIDTH_FP32, sum[2]);
      _MM_STORE_FP32(result_m_index + 3*SIMD_WIDTH_FP32, sum[3]);
      _MM_STORE_FP32(result_m_index + 4*SIMD_WIDTH_FP32, sum[4]);
      _MM_STORE_FP32(result_m_index + 5*SIMD_WIDTH_FP32, sum[5]);
    }
    else {
      int64_t j = 0;
      for (n = 0; n < num_full_regs; n += 2) {
        sum[n] = _MM_SETZERO_FP32();
        sum[n+1] = _MM_SETZERO_FP32();
      }
      for (; j < num_j; j++) {
        const float *const LIBXSMM_RESTRICT sp_col_dense_index = scratch_B_base + (size_t)sp_c_ptr_base[j]*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32;
        SIMDTYPE_FP32 v_v = _MM_SET1_FP32(sp_v_ptr_base[j]);
        for (n = 0; n < num_full_regs; n += 2) {
          sum[n] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n) *SIMD_WIDTH_FP32), sum[n]);
          sum[n+1] = _MM_FMADD_FP32(v_v, _MM_LOAD_FP32(sp_col_dense_index + ((size_t)n+1)*SIMD_WIDTH_FP32), sum[n+1]);
        }
        {
          float v_v_f = sp_v_ptr_base[j];
          for (n = last_n_start; n < num_n; n++) {
            result_m_index[n] += sp_col_dense_index[n]*v_v_f;
          }
        }
      }
      for (n = 0; n < num_full_regs; n += 2) {
        _MM_STORE_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n], _MM_LOAD_FP32(result_m_index + ((size_t)n) *SIMD_WIDTH_FP32)));
        _MM_STORE_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_ADD_FP32(sum[n+1], _MM_LOAD_FP32(result_m_index + ((size_t)n+1)*SIMD_WIDTH_FP32)));
      }

    }
  }
} /* kb */

/* Copy out c matrix */
if ('T' == transc || 't' == transc) {
  int num_m_simd = num_m / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
  int num_n_simd = num_n / SIMD_WIDTH_FP32 * SIMD_WIDTH_FP32;
  int n2;

  ptr_result = c + (size_t)n_overall_start*handle->m + m_overall_start;
  for (n = 0; n < num_n_simd; n += SIMD_WIDTH_FP32) {
    for (m = 0; m < num_m_simd; m += SIMD_WIDTH_FP32) {
      TRANSPOSE_SIMD_WIDTH_KERNEL(scratch_C + (size_t)m*n_block_size + n, n_block_size, ptr_result + (size_t)n*handle->m + m, handle->m);
    }
    /* Transpose a SIMD_WIDTH_FP32 * (num_m - num_m_simd) block of output space - input is of size (num_m - num_m_simd) * SIMD_WIDTH_FP32 */
    for (n2 = n; n2 < n + SIMD_WIDTH_FP32; n2++) {
      for (m = num_m_simd; m < num_m; m++) {
        ptr_result[n2*handle->m + m] = scratch_C[m*n_block_size + n2];
      }
    }
  }
  /* Transpose a (num_n - num_n_simd) * num_m block of output space - input is of size num_m * (num_n - num_n_simd) */
  for (n = num_n_simd; n < num_n; n++) {
    for (m = 0; m < num_m; m++) {
      ptr_result[n*handle->m + m] = scratch_C[m*n_block_size + n];
    }
  }
}
else {
  if (!last_block_n) {
    for (m = 0; m < num_m; m++) {
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 0*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 0*SIMD_WIDTH_FP32));
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 1*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 1*SIMD_WIDTH_FP32));
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 2*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 2*SIMD_WIDTH_FP32));
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 3*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 3*SIMD_WIDTH_FP32));
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 4*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 4*SIMD_WIDTH_FP32));
      _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + 5*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + 5*SIMD_WIDTH_FP32));
    }
  }
  else {
    for (m = 0; m < num_m; m++) {
      for (n = 0; n < num_full_regs; n += 2) {
        _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n) *SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n) *SIMD_WIDTH_FP32));
        _MM_STOREU_FP32(ptr_result + (size_t)m*handle->n + ((size_t)n+1)*SIMD_WIDTH_FP32, _MM_LOAD_FP32(scratch_C + (size_t)m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + ((size_t)n+1)*SIMD_WIDTH_FP32));
      }
      for (n = last_n_start; n < num_n; n++) {
        ptr_result[m*handle->n + n] = scratch_C[m*LIBXSMM_SPMDM_COMPUTE_NREGS*SIMD_WIDTH_FP32 + n];
      }
    }
  }
}

#undef LIBXSMM_SPMDM_COMPUTE_NREGS
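
/* For orientation: a plausible AVX2 instantiation of the SIMD macro layer this
 * template consumes (SIMDTYPE_*, _MM_*). The real, per-ISA definitions are
 * supplied by the translation unit that includes this template; everything in
 * this sketch is an assumption for illustration only, and it is compiled out.
 */
#if 0
#include <immintrin.h>
#define SIMD_WIDTH_FP32 (8)
#define SIMDTYPE_FP32 __m256
#define SIMDTYPE_INT32 __m256i
#define _MM_SETZERO_FP32() _mm256_setzero_ps()
#define _MM_SETZERO_INT32() _mm256_setzero_si256()
#define _MM_SET1_FP32(x) _mm256_set1_ps(x)
#define _MM_LOAD_FP32(p) _mm256_load_ps(p)
#define _MM_LOADU_FP32(p) _mm256_loadu_ps(p)
#define _MM_STORE_FP32(p, v) _mm256_store_ps(p, v)
#define _MM_STOREU_FP32(p, v) _mm256_storeu_ps(p, v)
#define _MM_ADD_FP32(a, b) _mm256_add_ps(a, b)
#define _MM_MUL_FP32(a, b) _mm256_mul_ps(a, b)
#define _MM_FMADD_FP32(a, b, c) _mm256_fmadd_ps(a, b, c)
#define _MM_LOADU_INT32(p) _mm256_loadu_si256((const __m256i*)(p))
/* Widen 16 bfloat16 values to 2 x 8 fp32: interleaving with vzero shifts each
 * bf16 bit pattern into the upper 16 bits of an int32 (exactly the scalar
 * "res.i <<= 16" trick used above); the lane permutes restore sequential order. */
#define EXPAND_BFLOAT16(v, vlo_final, vhi_final) do { \
  const __m256i vlo = _mm256_unpacklo_epi16(vzero, (v)); \
  const __m256i vhi = _mm256_unpackhi_epi16(vzero, (v)); \
  (vlo_final) = _mm256_castsi256_ps(_mm256_permute2f128_si256(vlo, vhi, 0x20)); \
  (vhi_final) = _mm256_castsi256_ps(_mm256_permute2f128_si256(vlo, vhi, 0x31)); \
} while (0)
#endif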