// BUG1989 is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 BUG1989. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#if __AVX__
static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
    const float* kernel = _kernel;

    // kernel memory packed 8 x 8
    kernel_tm.create(8 * kernel_size, inch, outch / 8 + (outch % 8) / 4 + outch % 4);

    int nn_outch = 0;
    int remain_outch_start = 0;

    nn_outch = outch >> 3;
    remain_outch_start = nn_outch << 3;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 8;

        const float* k0 = kernel + (p + 0) * inch * kernel_size;
        const float* k1 = kernel + (p + 1) * inch * kernel_size;
        const float* k2 = kernel + (p + 2) * inch * kernel_size;
        const float* k3 = kernel + (p + 3) * inch * kernel_size;
        const float* k4 = kernel + (p + 4) * inch * kernel_size;
        const float* k5 = kernel + (p + 5) * inch * kernel_size;
        const float* k6 = kernel + (p + 6) * inch * kernel_size;
        const float* k7 = kernel + (p + 7) * inch * kernel_size;

        float* ktmp = kernel_tm.channel(p / 8);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp[4] = k4[0];
            ktmp[5] = k5[0];
            ktmp[6] = k6[0];
            ktmp[7] = k7[0];
            ktmp += 8;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
            k4 += 1;
            k5 += 1;
            k6 += 1;
            k7 += 1;
        }
    }

    nn_outch = (outch - remain_outch_start) >> 2;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = remain_outch_start + pp * 4;

        const float* k0 = kernel + (p + 0) * inch * kernel_size;
        const float* k1 = kernel + (p + 1) * inch * kernel_size;
        const float* k2 = kernel + (p + 2) * inch * kernel_size;
        const float* k3 = kernel + (p + 3) * inch * kernel_size;

        float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp += 4;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
        }
    }

    remain_outch_start += nn_outch << 2;

    for (int p = remain_outch_start; p < outch; p++)
    {
        const float* k0 = kernel + (p + 0) * inch * kernel_size;

        float* ktmp = kernel_tm.channel(p / 8 + (p % 8) / 4 + p % 4);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp++;
            k0++;
        }
    }
}

static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    size_t elemsize = bottom_blob.elemsize;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias;

    // im2col
    Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
    {
        const int stride = kernel_h * kernel_w * outw * outh;
        float* ret = (float*)bottom_im2col;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const float* input = bottom_blob.channel(p);
            int retID = stride * p;
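            // im2col mapping: for each kernel offset (u, v), gather the input value that
            // this offset sees at every output position (i, j) into one contiguous
            // outw * outh row, so bottom_im2col becomes a
            // (inch * kernel_h * kernel_w) x (outw * outh) matrix.
            // Padding is assumed to have been applied to bottom_blob by the caller.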
            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    for (int i = 0; i < outh; i++)
                    {
                        for (int j = 0; j < outw; j++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // bottom_im2col memory packed 8 x 8
    Mat bottom_tm(8 * kernel_size, inch, out_size / 8 + out_size % 8, elemsize, opt.workspace_allocator);
    {
        int nn_size = out_size >> 3;
        int remain_size_start = nn_size << 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 8;

            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 8);

            for (int q = 0; q < inch * kernel_size; q++)
            {
#if __AVX__
                _mm256_storeu_ps(tmpptr, _mm256_loadu_ps(img0));
#else
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img0[2];
                tmpptr[3] = img0[3];
                tmpptr[4] = img0[4];
                tmpptr[5] = img0[5];
                tmpptr[6] = img0[6];
                tmpptr[7] = img0[7];
#endif // __AVX__
                tmpptr += 8;
                img0 += out_size;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 8 + i % 8);

            for (int q = 0; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr += 1;
                img0 += out_size;
            }
        }
    }

    // sgemm(int M, int N, int L, float* A, float* B, float* C)
    {
        //int M = outch;                    // outch
        int N = outw * outh;                // outsize or out stride
        int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 3;
        remain_outch_start = nn_outch << 3;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 8;

            float* output0 = top_blob.channel(i);
            float* output1 = top_blob.channel(i + 1);
            float* output2 = top_blob.channel(i + 2);
            float* output3 = top_blob.channel(i + 3);
            float* output4 = top_blob.channel(i + 4);
            float* output5 = top_blob.channel(i + 5);
            float* output6 = top_blob.channel(i + 6);
            float* output7 = top_blob.channel(i + 7);

            const float zeros[8] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
            const float* biasptr = bias ? bias + i : zeros;
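            // 8x8 micro-kernel: each iteration of the j-loop below computes an
            // 8 (output channels) x 8 (output positions) tile of top_blob,
            // accumulating over L = inch * kernel_h * kernel_w.  In scalar form the
            // tile update is roughly (illustrative sketch only):
            //
            //     for (int k = 0; k < L; k++)
            //         for (int oc = 0; oc < 8; oc++)
            //             for (int n = 0; n < 8; n++)
            //                 sum[oc][n] += va[k * 8 + oc] * vb[k * 8 + n];
            //
            // Note that the __AVX__ branch uses _mm256_fmadd_ps, so FMA support is
            // assumed whenever __AVX__ is enabled.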
            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                const float* vb = bottom_tm.channel(j / 8);
                const float* va = kernel_tm.channel(i / 8);

#if __AVX__
                __m256 _sum0 = _mm256_broadcast_ss(biasptr);
                __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
                __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
                __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);
                __m256 _sum4 = _mm256_broadcast_ss(biasptr + 4);
                __m256 _sum5 = _mm256_broadcast_ss(biasptr + 5);
                __m256 _sum6 = _mm256_broadcast_ss(biasptr + 6);
                __m256 _sum7 = _mm256_broadcast_ss(biasptr + 7);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _va1 = _mm256_broadcast_ss(va + 1);
                    __m256 _va2 = _mm256_broadcast_ss(va + 2);
                    __m256 _va3 = _mm256_broadcast_ss(va + 3);
                    __m256 _vb0 = _mm256_loadu_ps(vb);
                    __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                    __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                    __m256 _vb3 = _mm256_loadu_ps(vb + 24);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                    _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                    _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                    _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30

                    _va0 = _mm256_broadcast_ss(va + 4);
                    _va1 = _mm256_broadcast_ss(va + 5);
                    _va2 = _mm256_broadcast_ss(va + 6);
                    _va3 = _mm256_broadcast_ss(va + 7);

                    _sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
                    _sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
                    _sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
                    _sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70

                    va += 8;

                    // k1
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
                    _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
                    _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
                    _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31

                    _va0 = _mm256_broadcast_ss(va + 4);
                    _va1 = _mm256_broadcast_ss(va + 5);
                    _va2 = _mm256_broadcast_ss(va + 6);
                    _va3 = _mm256_broadcast_ss(va + 7);

                    _sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
                    _sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
                    _sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
                    _sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71

                    va += 8;

                    // k2
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
                    _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
                    _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
                    _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32

                    _va0 = _mm256_broadcast_ss(va + 4);
                    _va1 = _mm256_broadcast_ss(va + 5);
                    _va2 = _mm256_broadcast_ss(va + 6);
                    _va3 = _mm256_broadcast_ss(va + 7);

                    _sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
                    _sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
                    _sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
                    _sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72

                    va += 8;

                    // k3
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
                    _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
                    _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
                    _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33

                    _va0 = _mm256_broadcast_ss(va + 4);
                    _va1 = _mm256_broadcast_ss(va + 5);
                    _va2 = _mm256_broadcast_ss(va + 6);
                    _va3 = _mm256_broadcast_ss(va + 7);

                    _sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
                    _sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
                    _sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
                    _sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73

                    va += 8;
                    vb += 32;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _va1 = _mm256_broadcast_ss(va + 1);
                    __m256 _va2 = _mm256_broadcast_ss(va + 2);
                    __m256 _va3 = _mm256_broadcast_ss(va + 3);
                    __m256 _va4 = _mm256_broadcast_ss(va + 4);
                    __m256 _va5 = _mm256_broadcast_ss(va + 5);
                    __m256 _va6 = _mm256_broadcast_ss(va + 6);
                    __m256 _va7 = _mm256_broadcast_ss(va + 7);
                    __m256 _vb0 = _mm256_loadu_ps(vb);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                    _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                    _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                    _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
                    _sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4); // sum4 = (a00-a07) * k40
                    _sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5); // sum5 = (a00-a07) * k50
                    _sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6); // sum6 = (a00-a07) * k60
                    _sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7); // sum7 = (a00-a07) * k70

                    va += 8;
                    vb += 8;
                }

                _mm256_storeu_ps(output0, _sum0);
                _mm256_storeu_ps(output1, _sum1);
                _mm256_storeu_ps(output2, _sum2);
                _mm256_storeu_ps(output3, _sum3);
                _mm256_storeu_ps(output4, _sum4);
                _mm256_storeu_ps(output5, _sum5);
                _mm256_storeu_ps(output6, _sum6);
                _mm256_storeu_ps(output7, _sum7);
#else
                float sum0[8] = {0};
                float sum1[8] = {0};
                float sum2[8] = {0};
                float sum3[8] = {0};
                float sum4[8] = {0};
                float sum5[8] = {0};
                float sum6[8] = {0};
                float sum7[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        sum4[n] += va[4] * vb[n];
                        sum5[n] += va[5] * vb[n];
                        sum6[n] += va[6] * vb[n];
                        sum7[n] += va[7] * vb[n];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        sum4[n] += va[4] * vb[n + 8];
                        sum5[n] += va[5] * vb[n + 8];
                        sum6[n] += va[6] * vb[n + 8];
                        sum7[n] += va[7] * vb[n + 8];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        sum4[n] += va[4] * vb[n + 16];
                        sum5[n] += va[5] * vb[n + 16];
                        sum6[n] += va[6] * vb[n + 16];
                        sum7[n] += va[7] * vb[n + 16];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        sum4[n] += va[4] * vb[n + 24];
                        sum5[n] += va[5] * vb[n + 24];
                        sum6[n] += va[6] * vb[n + 24];
                        sum7[n] += va[7] * vb[n + 24];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 32];
                        sum1[n] += va[1] * vb[n + 32];
                        sum2[n] += va[2] * vb[n + 32];
                        sum3[n] += va[3] * vb[n + 32];
                        sum4[n] += va[4] * vb[n + 32];
                        sum5[n] += va[5] * vb[n + 32];
                        sum6[n] += va[6] * vb[n + 32];
                        sum7[n] += va[7] * vb[n + 32];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 40];
                        sum1[n] += va[1] * vb[n + 40];
                        sum2[n] += va[2] * vb[n + 40];
                        sum3[n] += va[3] * vb[n + 40];
                        sum4[n] += va[4] * vb[n + 40];
                        sum5[n] += va[5] * vb[n + 40];
                        sum6[n] += va[6] * vb[n + 40];
                        sum7[n] += va[7] * vb[n + 40];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 48];
                        sum1[n] += va[1] * vb[n + 48];
                        sum2[n] += va[2] * vb[n + 48];
                        sum3[n] += va[3] * vb[n + 48];
                        sum4[n] += va[4] * vb[n + 48];
                        sum5[n] += va[5] * vb[n + 48];
                        sum6[n] += va[6] * vb[n + 48];
                        sum7[n] += va[7] * vb[n + 48];
                        va += 8;

                        sum0[n] += va[0] * vb[n + 56];
                        sum1[n] += va[1] * vb[n + 56];
                        sum2[n] += va[2] * vb[n + 56];
                        sum3[n] += va[3] * vb[n + 56];
                        sum4[n] += va[4] * vb[n + 56];
                        sum5[n] += va[5] * vb[n + 56];
                        sum6[n] += va[6] * vb[n + 56];
                        sum7[n] += va[7] * vb[n + 56];
                        va -= 56;
                    }

                    va += 64;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        sum4[n] += va[4] * vb[n];
                        sum5[n] += va[5] * vb[n];
                        sum6[n] += va[6] * vb[n];
                        sum7[n] += va[7] * vb[n];
                    }

                    va += 8;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                    output4[n] = sum4[n] + biasptr[4];
                    output5[n] = sum5[n] + biasptr[5];
                    output6[n] = sum6[n] + biasptr[6];
                    output7[n] = sum7[n] + biasptr[7];
                }
#endif // __AVX__
                output0 += 8;
                output1 += 8;
                output2 += 8;
                output3 += 8;
                output4 += 8;
                output5 += 8;
                output6 += 8;
                output7 += 8;
            }

            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 8 + j % 8);
                const float* va = kernel_tm.channel(i / 8);

#if __AVX__
                __m256 _sum0_7 = _mm256_loadu_ps(biasptr);
                __m256 _sum0 = _mm256_set1_ps(0.0);
                __m256 _sum1 = _mm256_set1_ps(0.0);
                __m256 _sum2 = _mm256_set1_ps(0.0);
                __m256 _sum3 = _mm256_set1_ps(0.0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    __m256 _vb0 = _mm256_broadcast_ss(vb);
                    __m256 _vb1 = _mm256_broadcast_ss(vb + 1);
                    __m256 _vb2 = _mm256_broadcast_ss(vb + 2);
                    __m256 _vb3 = _mm256_broadcast_ss(vb + 3);
                    __m256 _va0 = _mm256_loadu_ps(va);
                    __m256 _va1 = _mm256_loadu_ps(va + 8);
                    __m256 _va2 = _mm256_loadu_ps(va + 16);
                    __m256 _va3 = _mm256_loadu_ps(va + 24);

                    _sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k70) * a00
                    _sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k71) * a10
                    _sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k72) * a20
                    _sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k73) * a30

                    va += 32;
                    vb += 4;
                }

                _sum0 = _mm256_add_ps(_sum0, _sum1);
                _sum2 = _mm256_add_ps(_sum2, _sum3);
                _sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
                _sum0_7 = _mm256_add_ps(_sum0_7, _sum2);

                for (; k < L; k++)
                {
                    __m256 _vb0 = _mm256_broadcast_ss(vb);
                    __m256 _va = _mm256_loadu_ps(va);

                    _sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7); // sum0 += (k00-k70) * a00

                    va += 8;
                    vb += 1;
                }

                float output_sum0_7[8] = {0.f};
                _mm256_storeu_ps(output_sum0_7, _sum0_7);

                output0[0] = output_sum0_7[0];
                output1[0] = output_sum0_7[1];
                output2[0] = output_sum0_7[2];
                output3[0] = output_sum0_7[3];
                output4[0] = output_sum0_7[4];
                output5[0] = output_sum0_7[5];
                output6[0] = output_sum0_7[6];
                output7[0] = output_sum0_7[7];
#else
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];
                float sum4 = biasptr[4];
                float sum5 = biasptr[5];
                float sum6 = biasptr[6];
                float sum7 = biasptr[7];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];
                    sum4 += va[4] * vb[0];
                    sum5 += va[5] * vb[0];
                    sum6 += va[6] * vb[0];
                    sum7 += va[7] * vb[0];

                    va += 8;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
                output4[0] = sum4;
                output5[0] = sum5;
                output6[0] = sum6;
                output7[0] = sum7;
#endif // __AVX__
                output0++;
                output1++;
                output2++;
                output3++;
                output4++;
                output5++;
                output6++;
                output7++;
            }
        }
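        // Remaining output channels (outch % 8): handled below first in groups of
        // four channels, then one channel at a time, matching the packing produced
        // by conv_im2col_sgemm_transform_kernel_sse.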
        nn_outch = (outch - remain_outch_start) >> 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = remain_outch_start + pp * 4;

            float* output0 = top_blob.channel(i);
            float* output1 = top_blob.channel(i + 1);
            float* output2 = top_blob.channel(i + 2);
            float* output3 = top_blob.channel(i + 3);

            const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
            const float* biasptr = bias ? bias + i : zeros;

            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                const float* vb = bottom_tm.channel(j / 8);
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);

#if __AVX__
                __m256 _sum0 = _mm256_broadcast_ss(biasptr);
                __m256 _sum1 = _mm256_broadcast_ss(biasptr + 1);
                __m256 _sum2 = _mm256_broadcast_ss(biasptr + 2);
                __m256 _sum3 = _mm256_broadcast_ss(biasptr + 3);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _va1 = _mm256_broadcast_ss(va + 1);
                    __m256 _va2 = _mm256_broadcast_ss(va + 2);
                    __m256 _va3 = _mm256_broadcast_ss(va + 3);
                    __m256 _vb0 = _mm256_loadu_ps(vb);
                    __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                    __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                    __m256 _vb3 = _mm256_loadu_ps(vb + 24);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                    _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                    _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                    _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30

                    va += 4;

                    // k1
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
                    _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
                    _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
                    _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31

                    va += 4;

                    // k2
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
                    _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
                    _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
                    _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32

                    va += 4;

                    // k3
                    _va0 = _mm256_broadcast_ss(va);
                    _va1 = _mm256_broadcast_ss(va + 1);
                    _va2 = _mm256_broadcast_ss(va + 2);
                    _va3 = _mm256_broadcast_ss(va + 3);

                    _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
                    _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
                    _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
                    _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33

                    va += 4;
                    vb += 32;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _va1 = _mm256_broadcast_ss(va + 1);
                    __m256 _va2 = _mm256_broadcast_ss(va + 2);
                    __m256 _va3 = _mm256_broadcast_ss(va + 3);
                    __m256 _vb0 = _mm256_loadu_ps(vb);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                    _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                    _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                    _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30

                    va += 4;
                    vb += 8;
                }

                _mm256_storeu_ps(output0, _sum0);
                _mm256_storeu_ps(output1, _sum1);
                _mm256_storeu_ps(output2, _sum2);
                _mm256_storeu_ps(output3, _sum3);
#else
                float sum0[8] = {0};
                float sum1[8] = {0};
                float sum2[8] = {0};
                float sum3[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 32];
                        sum1[n] += va[1] * vb[n + 32];
                        sum2[n] += va[2] * vb[n + 32];
                        sum3[n] += va[3] * vb[n + 32];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 40];
                        sum1[n] += va[1] * vb[n + 40];
                        sum2[n] += va[2] * vb[n + 40];
                        sum3[n] += va[3] * vb[n + 40];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 48];
                        sum1[n] += va[1] * vb[n + 48];
                        sum2[n] += va[2] * vb[n + 48];
                        sum3[n] += va[3] * vb[n + 48];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 56];
                        sum1[n] += va[1] * vb[n + 56];
                        sum2[n] += va[2] * vb[n + 56];
                        sum3[n] += va[3] * vb[n + 56];
                        va -= 28;
                    }

                    va += 32;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                    }

                    va += 4;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                }
#endif // __AVX__
                output0 += 8;
                output1 += 8;
                output2 += 8;
                output3 += 8;
            }

            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 8 + j % 8);
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4);

#if __AVX__
                __m128 _sum0_3 = _mm_loadu_ps(biasptr);
                __m128 _sum0 = _mm_set1_ps(0.0);
                __m128 _sum1 = _mm_set1_ps(0.0);
                __m128 _sum2 = _mm_set1_ps(0.0);
                __m128 _sum3 = _mm_set1_ps(0.0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _vb1 = _mm_set1_ps(vb[1]);
                    __m128 _vb2 = _mm_set1_ps(vb[2]);
                    __m128 _vb3 = _mm_set1_ps(vb[3]);
                    __m128 _va0 = _mm_loadu_ps(va);
                    __m128 _va1 = _mm_loadu_ps(va + 4);
                    __m128 _va2 = _mm_loadu_ps(va + 8);
                    __m128 _va3 = _mm_loadu_ps(va + 12);

                    _sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k30) * a00
                    _sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k31) * a10
                    _sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k32) * a20
                    _sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k33) * a30

                    va += 16;
                    vb += 4;
                }

                _sum0 = _mm_add_ps(_sum0, _sum1);
                _sum2 = _mm_add_ps(_sum2, _sum3);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum2);

                for (; k < L; k++)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _va = _mm_loadu_ps(va);

                    _sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3); // sum0 += (k00-k30) * a00

                    va += 4;
                    vb += 1;
                }

                float output_sum0_3[4] = {0.f};
                _mm_storeu_ps(output_sum0_3, _sum0_3);

                output0[0] = output_sum0_3[0];
                output1[0] = output_sum0_3[1];
                output2[0] = output_sum0_3[2];
                output3[0] = output_sum0_3[3];
#else
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __AVX__
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }
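        // Last remaining output channels (outch % 4): computed one channel at a
        // time, eight output positions per step in the AVX branch.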
        remain_outch_start += nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;

            int j = 0;
            for (; j + 7 < N; j = j + 8)
            {
                const float* vb = bottom_tm.channel(j / 8);
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);

#if __AVX__
                __m256 _sum0 = _mm256_broadcast_ss(&bias0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _va1 = _mm256_broadcast_ss(va + 1);
                    __m256 _va2 = _mm256_broadcast_ss(va + 2);
                    __m256 _va3 = _mm256_broadcast_ss(va + 3);
                    __m256 _vb0 = _mm256_loadu_ps(vb);
                    __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                    __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                    __m256 _vb3 = _mm256_loadu_ps(vb + 24);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                    _sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
                    _sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
                    _sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03

                    va += 4;
                    vb += 32;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m256 _va0 = _mm256_broadcast_ss(va);
                    __m256 _vb0 = _mm256_loadu_ps(vb);

                    _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00

                    va += 1;
                    vb += 8;
                }

                _mm256_storeu_ps(output, _sum0);
#else
                float sum[8] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 8];
                        sum[n] += va[2] * vb[n + 16];
                        sum[n] += va[3] * vb[n + 24];
                        sum[n] += va[4] * vb[n + 32];
                        sum[n] += va[5] * vb[n + 40];
                        sum[n] += va[6] * vb[n + 48];
                        sum[n] += va[7] * vb[n + 56];
                    }

                    va += 8;
                    vb += 64;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 8; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 8;
                }

                for (int n = 0; n < 8; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __AVX__
                output += 8;
            }

            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 8 + j % 8);
                const float* va = kernel_tm.channel(i / 8 + (i % 8) / 4 + i % 4);

                int k = 0;
#if __AVX__
                __m128 _sum0 = _mm_set1_ps(0.f);

                for (; k + 3 < L; k += 4)
                {
                    __m128 _p0 = _mm_loadu_ps(vb);
                    vb += 4;

                    __m128 _k0 = _mm_loadu_ps(va);
                    va += 4;

                    _sum0 = _mm_fmadd_ps(_p0, _k0, _sum0);
                }

                float output_sum0[4] = {0.f};
                _mm_storeu_ps(output_sum0, _sum0);

                float sum0 = bias0 + output_sum0[0] + output_sum0[1] + output_sum0[2] + output_sum0[3];
#else
                float sum0 = bias0;
#endif // __AVX__
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }

                output[0] = sum0;
                output++;
            }
        }
    }
}
#else
static void conv_im2col_sgemm_transform_kernel_sse(const Mat& _kernel, Mat& kernel_tm, int inch, int outch, int kernel_size)
{
    const float* kernel = _kernel;

    // kernel memory packed 4 x 4
    kernel_tm.create(4 * kernel_size, inch, outch / 4 + outch % 4);

    int nn_outch = 0;
    int remain_outch_start = 0;

    nn_outch = outch >> 2;
    remain_outch_start = nn_outch << 2;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 4;

        const float* k0 = kernel + (p + 0) * inch * kernel_size;
        const float* k1 = kernel + (p + 1) * inch * kernel_size;
        const float* k2 = kernel + (p + 2) * inch * kernel_size;
        const float* k3 = kernel + (p + 3) * inch * kernel_size;

        float* ktmp = kernel_tm.channel(p / 4);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp[1] = k1[0];
            ktmp[2] = k2[0];
            ktmp[3] = k3[0];
            ktmp += 4;

            k0 += 1;
            k1 += 1;
            k2 += 1;
            k3 += 1;
        }
    }

    for (int p = remain_outch_start; p < outch; p++)
    {
        const float* k0 = kernel + (p + 0) * inch * kernel_size;

        float* ktmp = kernel_tm.channel(p / 4 + p % 4);

        for (int q = 0; q < inch * kernel_size; q++)
        {
            ktmp[0] = k0[0];
            ktmp++;
            k0++;
        }
    }
}
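// Layout of kernel_tm produced above (4-wide packing): channel g holds the weights
// of output channels 4g..4g+3 interleaved element by element,
//
//     k(4g+0)[0], k(4g+1)[0], k(4g+2)[0], k(4g+3)[0], k(4g+0)[1], k(4g+1)[1], ...
//
// and any leftover output channels (outch % 4) are stored unpacked, one per channel.
// The AVX build above uses the same idea with an 8-wide leading pack.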
static void conv_im2col_sgemm_sse(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel_tm, const Mat& _bias, const int kernel_w, const int kernel_h, const int stride_w, const int stride_h, const Option& opt)
{
    int w = bottom_blob.w;
    int inch = bottom_blob.c;

    size_t elemsize = bottom_blob.elemsize;

    int outw = top_blob.w;
    int outh = top_blob.h;
    int outch = top_blob.c;

    const float* bias = _bias;

    // im2col
    Mat bottom_im2col(outw * outh, kernel_h * kernel_w * inch, elemsize, opt.workspace_allocator);
    {
        const int stride = kernel_h * kernel_w * outw * outh;
        float* ret = (float*)bottom_im2col;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int p = 0; p < inch; p++)
        {
            const float* input = bottom_blob.channel(p);
            int retID = stride * p;
            for (int u = 0; u < kernel_h; u++)
            {
                for (int v = 0; v < kernel_w; v++)
                {
                    for (int i = 0; i < outh; i++)
                    {
                        for (int j = 0; j < outw; j++)
                        {
                            int row = u + i * stride_h;
                            int col = v + j * stride_w;
                            int index = row * w + col;
                            ret[retID] = input[index];
                            retID++;
                        }
                    }
                }
            }
        }
    }

    int kernel_size = kernel_w * kernel_h;
    int out_size = outw * outh;

    // bottom_im2col memory packed 4 x 4
    Mat bottom_tm(4 * kernel_size, inch, out_size / 4 + out_size % 4, elemsize, opt.workspace_allocator);
    {
        int nn_size = out_size >> 2;
        int remain_size_start = nn_size << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int ii = 0; ii < nn_size; ii++)
        {
            int i = ii * 4;

            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 4);

            for (int q = 0; q < inch * kernel_size; q++)
            {
#if __SSE__
                _mm_storeu_ps(tmpptr, _mm_loadu_ps(img0));
#else
                tmpptr[0] = img0[0];
                tmpptr[1] = img0[1];
                tmpptr[2] = img0[2];
                tmpptr[3] = img0[3];
#endif // __SSE__
                tmpptr += 4;
                img0 += out_size;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_size_start; i < out_size; i++)
        {
            const float* img0 = bottom_im2col.channel(0);
            img0 += i;

            float* tmpptr = bottom_tm.channel(i / 4 + i % 4);

            for (int q = 0; q < inch * kernel_size; q++)
            {
                tmpptr[0] = img0[0];
                tmpptr += 1;
                img0 += out_size;
            }
        }
    }

    // sgemm(int M, int N, int L, float* A, float* B, float* C)
    {
        //int M = outch;                    // outch
        int N = outw * outh;                // outsize or out stride
        int L = kernel_w * kernel_h * inch; // ksize * inch

        int nn_outch = 0;
        int remain_outch_start = 0;

        nn_outch = outch >> 2;
        remain_outch_start = nn_outch << 2;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int pp = 0; pp < nn_outch; pp++)
        {
            int i = pp * 4;

            float* output0 = top_blob.channel(i);
            float* output1 = top_blob.channel(i + 1);
            float* output2 = top_blob.channel(i + 2);
            float* output3 = top_blob.channel(i + 3);

            const float zeros[4] = {0.f, 0.f, 0.f, 0.f};
            const float* biasptr = bias ? bias + i : zeros;
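            // 4x4 micro-kernel (SSE only, no FMA): each j-iteration computes a
            // 4 (output channels) x 4 (output positions) tile with
            // _mm_mul_ps/_mm_add_ps, i.e. roughly (illustrative sketch only)
            // sum[oc][n] += va[k * 4 + oc] * vb[k * 4 + n] for k in [0, L).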
            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const float* vb = bottom_tm.channel(j / 4);
                const float* va = kernel_tm.channel(i / 4);

#if __SSE__
                __m128 _sum0 = _mm_set1_ps(biasptr[0]);
                __m128 _sum1 = _mm_set1_ps(biasptr[1]);
                __m128 _sum2 = _mm_set1_ps(biasptr[2]);
                __m128 _sum3 = _mm_set1_ps(biasptr[3]);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m128 _vb = _mm_loadu_ps(vb);
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30

                    // k1
                    _vb = _mm_loadu_ps(vb + 4);
                    _va0 = _mm_set1_ps(va[4]);
                    _va1 = _mm_set1_ps(va[5]);
                    _va2 = _mm_set1_ps(va[6]);
                    _va3 = _mm_set1_ps(va[7]);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a10-a13) * k01
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a10-a13) * k11
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a10-a13) * k21
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a10-a13) * k31

                    // k2
                    _vb = _mm_loadu_ps(vb + 8);
                    _va0 = _mm_set1_ps(va[8]);
                    _va1 = _mm_set1_ps(va[9]);
                    _va2 = _mm_set1_ps(va[10]);
                    _va3 = _mm_set1_ps(va[11]);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a20-a23) * k02
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a20-a23) * k12
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a20-a23) * k22
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a20-a23) * k32

                    // k3
                    _vb = _mm_loadu_ps(vb + 12);
                    _va0 = _mm_set1_ps(va[12]);
                    _va1 = _mm_set1_ps(va[13]);
                    _va2 = _mm_set1_ps(va[14]);
                    _va3 = _mm_set1_ps(va[15]);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a30-a33) * k03
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a30-a33) * k13
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a30-a33) * k23
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a30-a33) * k33

                    va += 16;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m128 _vb = _mm_loadu_ps(vb);
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb, _va0)); // sum0 = (a00-a03) * k00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_vb, _va1)); // sum1 = (a00-a03) * k10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_vb, _va2)); // sum2 = (a00-a03) * k20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_vb, _va3)); // sum3 = (a00-a03) * k30

                    va += 4;
                    vb += 4;
                }

                _mm_storeu_ps(output0, _sum0);
                _mm_storeu_ps(output1, _sum1);
                _mm_storeu_ps(output2, _sum2);
                _mm_storeu_ps(output3, _sum3);
#else
                float sum0[4] = {0};
                float sum1[4] = {0};
                float sum2[4] = {0};
                float sum3[4] = {0};

                int k = 0;
                for (; k + 7 < L; k = k + 8)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 4];
                        sum1[n] += va[1] * vb[n + 4];
                        sum2[n] += va[2] * vb[n + 4];
                        sum3[n] += va[3] * vb[n + 4];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 8];
                        sum1[n] += va[1] * vb[n + 8];
                        sum2[n] += va[2] * vb[n + 8];
                        sum3[n] += va[3] * vb[n + 8];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 12];
                        sum1[n] += va[1] * vb[n + 12];
                        sum2[n] += va[2] * vb[n + 12];
                        sum3[n] += va[3] * vb[n + 12];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 16];
                        sum1[n] += va[1] * vb[n + 16];
                        sum2[n] += va[2] * vb[n + 16];
                        sum3[n] += va[3] * vb[n + 16];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 20];
                        sum1[n] += va[1] * vb[n + 20];
                        sum2[n] += va[2] * vb[n + 20];
                        sum3[n] += va[3] * vb[n + 20];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 24];
                        sum1[n] += va[1] * vb[n + 24];
                        sum2[n] += va[2] * vb[n + 24];
                        sum3[n] += va[3] * vb[n + 24];
                        va += 4;

                        sum0[n] += va[0] * vb[n + 28];
                        sum1[n] += va[1] * vb[n + 28];
                        sum2[n] += va[2] * vb[n + 28];
                        sum3[n] += va[3] * vb[n + 28];
                        va -= 28;
                    }

                    va += 32;
                    vb += 32;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum0[n] += va[0] * vb[n];
                        sum1[n] += va[1] * vb[n];
                        sum2[n] += va[2] * vb[n];
                        sum3[n] += va[3] * vb[n];
                    }

                    va += 4;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output0[n] = sum0[n] + biasptr[0];
                    output1[n] = sum1[n] + biasptr[1];
                    output2[n] = sum2[n] + biasptr[2];
                    output3[n] = sum3[n] + biasptr[3];
                }
#endif // __SSE__
                output0 += 4;
                output1 += 4;
                output2 += 4;
                output3 += 4;
            }

            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4);

#if __SSE__
                __m128 _sum0_3 = _mm_loadu_ps(biasptr);
                __m128 _sum0 = _mm_set1_ps(0.0);
                __m128 _sum1 = _mm_set1_ps(0.0);
                __m128 _sum2 = _mm_set1_ps(0.0);
                __m128 _sum3 = _mm_set1_ps(0.0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _vb1 = _mm_set1_ps(vb[1]);
                    __m128 _vb2 = _mm_set1_ps(vb[2]);
                    __m128 _vb3 = _mm_set1_ps(vb[3]);
                    __m128 _va0 = _mm_loadu_ps(va);
                    __m128 _va1 = _mm_loadu_ps(va + 4);
                    __m128 _va2 = _mm_loadu_ps(va + 8);
                    __m128 _va3 = _mm_loadu_ps(va + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_va0, _vb0)); // sum0 += (k00-k30) * a00
                    _sum1 = _mm_add_ps(_sum1, _mm_mul_ps(_va1, _vb1)); // sum1 += (k01-k31) * a10
                    _sum2 = _mm_add_ps(_sum2, _mm_mul_ps(_va2, _vb2)); // sum2 += (k02-k32) * a20
                    _sum3 = _mm_add_ps(_sum3, _mm_mul_ps(_va3, _vb3)); // sum3 += (k03-k33) * a30

                    va += 16;
                    vb += 4;
                }

                _sum0 = _mm_add_ps(_sum0, _sum1);
                _sum2 = _mm_add_ps(_sum2, _sum3);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
                _sum0_3 = _mm_add_ps(_sum0_3, _sum2);

                for (; k < L; k++)
                {
                    __m128 _vb0 = _mm_set1_ps(vb[0]);
                    __m128 _va = _mm_loadu_ps(va);

                    _sum0_3 = _mm_add_ps(_sum0_3, _mm_mul_ps(_va, _vb0)); // sum0 += (k00-k30) * a00

                    va += 4;
                    vb += 1;
                }

                float sum0_3_tmp[4];
                _mm_storeu_ps(sum0_3_tmp, _sum0_3);

                output0[0] = sum0_3_tmp[0];
                output1[0] = sum0_3_tmp[1];
                output2[0] = sum0_3_tmp[2];
                output3[0] = sum0_3_tmp[3];
#else
                float sum0 = biasptr[0];
                float sum1 = biasptr[1];
                float sum2 = biasptr[2];
                float sum3 = biasptr[3];

                for (int k = 0; k < L; k++)
                {
                    sum0 += va[0] * vb[0];
                    sum1 += va[1] * vb[0];
                    sum2 += va[2] * vb[0];
                    sum3 += va[3] * vb[0];

                    va += 4;
                    vb += 1;
                }

                output0[0] = sum0;
                output1[0] = sum1;
                output2[0] = sum2;
                output3[0] = sum3;
#endif // __SSE__
                output0++;
                output1++;
                output2++;
                output3++;
            }
        }

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int i = remain_outch_start; i < outch; i++)
        {
            float* output = top_blob.channel(i);

            const float bias0 = bias ? bias[i] : 0.f;
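            // Remaining output channels (outch % 4): one channel at a time, four
            // output positions per step, with leftover columns handled individually.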
            int j = 0;
            for (; j + 3 < N; j = j + 4)
            {
                const float* vb = bottom_tm.channel(j / 4);
                const float* va = kernel_tm.channel(i / 4 + i % 4);

#if __SSE__
                __m128 _sum0 = _mm_set1_ps(bias0);

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _va1 = _mm_set1_ps(va[1]);
                    __m128 _va2 = _mm_set1_ps(va[2]);
                    __m128 _va3 = _mm_set1_ps(va[3]);
                    __m128 _vb0 = _mm_loadu_ps(vb);
                    __m128 _vb1 = _mm_loadu_ps(vb + 4);
                    __m128 _vb2 = _mm_loadu_ps(vb + 8);
                    __m128 _vb3 = _mm_loadu_ps(vb + 12);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb1, _va1)); // sum0 += (a10-a13) * k01
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb2, _va2)); // sum0 += (a20-a23) * k02
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb3, _va3)); // sum0 += (a30-a33) * k03

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    // k0
                    __m128 _va0 = _mm_set1_ps(va[0]);
                    __m128 _vb0 = _mm_loadu_ps(vb);

                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_vb0, _va0)); // sum0 = (a00-a03) * k00

                    va += 1;
                    vb += 4;
                }

                _mm_storeu_ps(output, _sum0);
#else
                float sum[4] = {0};

                int k = 0;
                for (; k + 3 < L; k = k + 4)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                        sum[n] += va[1] * vb[n + 4];
                        sum[n] += va[2] * vb[n + 8];
                        sum[n] += va[3] * vb[n + 12];
                        //sum[n] += va[4] * vb[n+16];
                        //sum[n] += va[5] * vb[n+20];
                        //sum[n] += va[6] * vb[n+24];
                        //sum[n] += va[7] * vb[n+28];
                    }

                    va += 4;
                    vb += 16;
                }

                for (; k < L; k++)
                {
                    for (int n = 0; n < 4; n++)
                    {
                        sum[n] += va[0] * vb[n];
                    }

                    va += 1;
                    vb += 4;
                }

                for (int n = 0; n < 4; n++)
                {
                    output[n] = sum[n] + bias0;
                }
#endif // __SSE__
                output += 4;
            }

            for (; j < N; j++)
            {
                const float* vb = bottom_tm.channel(j / 4 + j % 4);
                const float* va = kernel_tm.channel(i / 4 + i % 4);

                int k = 0;
#if __SSE__
                __m128 _sum0 = _mm_set1_ps(0.f);

                for (; k + 3 < L; k += 4)
                {
                    __m128 _p0 = _mm_loadu_ps(vb);
                    __m128 _k0 = _mm_loadu_ps(va);
                    _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));

                    va += 4;
                    vb += 4;
                }

                float sum0_tmp[4];
                _mm_storeu_ps(sum0_tmp, _sum0);

                float sum0 = bias0 + sum0_tmp[0] + sum0_tmp[1] + sum0_tmp[2] + sum0_tmp[3];
#else
                float sum0 = bias0;
#endif // __SSE__
                for (; k < L; k++)
                {
                    sum0 += va[0] * vb[0];

                    va += 1;
                    vb += 1;
                }

                output[0] = sum0;
                output++;
            }
        }
    }
}
#endif
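// Usage sketch (illustrative only; the names below follow the convention used by
// the calling code in convolution_x86.cpp and are assumptions here): pack the
// weights once, then run the im2col + sgemm kernel on the already-padded input on
// every forward pass.
//
//     Mat weight_sgemm_data;
//     conv_im2col_sgemm_transform_kernel_sse(weight_data, weight_sgemm_data,
//                                            inch, outch, kernel_w * kernel_h);
//     ...
//     conv_im2col_sgemm_sse(bottom_blob_bordered, top_blob, weight_sgemm_data,
//                           bias_data, kernel_w, kernel_h, stride_w, stride_h, opt);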