1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #include <stdio.h>
14 #include <omp.h>
15 #include <immintrin.h>
16 #include "SplitLoop.hpp"
17 
/* bf16 <-> FP32 helpers for the AVX-512 code paths.
 * Everything that touches AVX-512 intrinsics is guarded by __AVX512F__ so that
 * a build without AVX-512 can still compile the scalar code in this file.
 * All macros evaluate their arguments more than once: pass side-effect-free
 * expressions only. */
#if defined(__AVX512F__)
/* Load 16 bf16 values from A (unaligned) and widen them to 16 FP32 lanes by
 * placing each 16-bit value in the upper half of a 32-bit lane. */
# define _mm512_load_act(A)     _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16))
# if 1 /* 1: round to nearest even when narrowing FP32 -> bf16; 0: truncate */
/* Rounding constants. They are file-local (static) so that they cannot clash
 * with identically named symbols in other translation units. */
static const __m512i vnaninf = _mm512_set1_epi32( 0x7f800000 );    /* all exponent bits set: Inf/NaN */
static const __m512i vrneadd = _mm512_set1_epi32( 0x00007fff );    /* rounding bias for the dropped 16 bits */
static const __m512i vfixup = _mm512_set1_epi32( 0x00000001 );     /* extra +1 when the kept LSB is set (ties to even) */
static const __m512i vfixupmask = _mm512_set1_epi32( 0x00010000 ); /* lowest bit that survives the narrowing */
/* Returns the FP32 lanes of A as integers with the round-to-nearest-even bias
 * added; lanes whose exponent bits are all set (Inf/NaN) are passed through
 * unchanged. The caller shifts the result right by 16 to obtain bf16. */
#  define _mm512_roundbf16rne(A) _mm512_mask_add_epi32( _mm512_castps_si512( A ), _mm512_cmp_epi32_mask( _mm512_and_epi32( _mm512_castps_si512( A ), vnaninf ), vnaninf, _MM_CMPINT_NE ), _mm512_castps_si512( A ), _mm512_mask_add_epi32( vrneadd , _mm512_cmp_epi32_mask( _mm512_and_epi32( _mm512_castps_si512( A ), vfixupmask ), vfixupmask, _MM_CMPINT_EQ ), vrneadd, vfixup ) )
/* Narrow the 16 FP32 lanes of B to bf16 and write them to A.
 * _mm512_stream_act uses a non-temporal 256-bit store, _mm512_store_act an
 * unaligned 256-bit store. A is parenthesized so that pointer arithmetic in
 * the argument (e.g. p+j) is applied before the cast. */
#  define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
#  define _mm512_store_act(A,B)  _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
# else
/* Truncating variants: keep the upper 16 bits of every FP32 lane. */
#  define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
#  define _mm512_store_act(A,B)  _mm256_storeu_si256((__m256i*)(A),_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
# endif
#endif /* __AVX512F__ */

/* Number of 32-bit lanes in one 512-bit vector. */
#define VLEN 16
33 
forwardPropagate(TensorBuf * inpb,vector<TensorBuf * > & outpb,int tid)34 void SplitLoop::forwardPropagate(TensorBuf *inpb, vector<TensorBuf*>& outpb, int tid)
35 {
36   for(int i=0; i<outpb.size(); i++)
37   {
38     outpb[i]->setBuffer(inpb->getBuffer());
39     outpb[i]->setBufferSize(inpb->getBufferSize());
40     outpb[i]->setLayoutType(inpb->getLayoutType());
41   }
42 }
43 
backPropagate(vector<TensorBuf * > & deloutpb,TensorBuf * delinpb,int tid)44 void SplitLoop::backPropagate(vector<TensorBuf *>& deloutpb, TensorBuf *delinpb, int tid)
45 {
46   assert(gp->bdims == gp->tdims);
47 
48   int nImg = gp->batch_size;
49   int nIfm = gp->nInput;
50   int ifh = gp->iHeight;
51   int ifw = gp->iWidth;
52 
53   int in_dtype = delinpb->getDataType();
54   int out_dtype = deloutpb[0]->getDataType();
55 
56   void* delinp = delinpb->getBuffer();
57 
58   void *deloutp[deloutpb.size()];
59   int num_outp = 1;
60   int size = nImg*nIfm*ifh*ifw;
61 
62   deloutp[0] = deloutpb[0]->getBuffer();
63 
64   for(int i=1; i<deloutpb.size(); i++)
65   {
66     if(deloutpb[i] == NULL) continue;
67 
68     deloutp[num_outp] = deloutpb[i]->getBuffer();
69     num_outp++;
70   }
71 
72   if(in_dtype == DT_FLOAT && out_dtype == DT_FLOAT)
73   {
74 #ifdef __AVX512F__
75     if (size % 16 == 0) {
76       if ( num_outp == 2 ) {
77         float* out1 = (float*)deloutp[0];
78         float* out2 = (float*)deloutp[1];
79 #ifdef _OPENMP
80 #pragma omp parallel for
81 #endif
82         for(int j=0; j<size; j+=16) {
83           __m512 vo = _mm512_loadu_ps( out1+j );
84           vo = _mm512_add_ps( vo, _mm512_loadu_ps( out2+j ) );
85 #ifdef USE_NTS_SPLIT
86           _mm512_stream_ps( &(((float*)delinp)[j]), vo );
87 #else
88           _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
89 #endif
90         }
91       } else if ( num_outp == 1 ) {
92         float* out1 = (float*)deloutp[0];
93 #ifdef _OPENMP
94 #pragma omp parallel for
95 #endif
96         for(int j=0; j<size; j+=16) {
97           __m512 vo = _mm512_loadu_ps( out1+j );
98 #ifdef USE_NTS_SPLIT
99           _mm512_stream_ps( &(((float*)delinp)[j]), vo );
100 #else
101           _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
102 #endif
103         }
104       } else {
105 #ifdef _OPENMP
106 #pragma omp parallel for
107 #endif
108         for(int j=0; j<size; j+=16) {
109           __m512 vo = _mm512_loadu_ps( &(((float*)deloutp[0])[j]) );
110           for(int i=1; i<num_outp; i++) {
111             vo = _mm512_add_ps( vo, _mm512_loadu_ps( &(((float*)deloutp[i])[j]) ) );
112           }
113 #ifdef USE_NTS_SPLIT
114           _mm512_stream_ps( &(((float*)delinp)[j]), vo );
115 #else
116           _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
117 #endif
118         }
119       }
120     } else {
121 #ifdef _OPENMP
122 #pragma omp parallel for
123 #endif
124       for(int j=0; j<size; j++) {
125         float o = ((float*)deloutp[0])[j];
126         for(int i=1; i<num_outp; i++) {
127           o += ((float*)deloutp[i])[j];
128         }
129         ((float*)delinp)[j] = o;
130       }
131     }
132 #else
133 #ifdef _OPENMP
134 #pragma omp parallel for
135 #endif
136     for(int j=0; j<size; j++) {
137       float o = ((float*)deloutp[0])[j];
138       for(int i=1; i<num_outp; i++) {
139         o += ((float*)deloutp[i])[j];
140       }
141       delinp[j] = o;
142     }
143 #endif
144   }
145   else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
146   {
147 #ifdef __AVX512F__
148     if (size % 16 == 0) {
149       if ( num_outp == 2 ) {
150         libxsmm_bfloat16* out1 = (libxsmm_bfloat16*)deloutp[0];
151         libxsmm_bfloat16* out2 = (libxsmm_bfloat16*)deloutp[1];
152 #ifdef _OPENMP
153 #pragma omp parallel for
154 #endif
155         for(int j=0; j<size; j+=16) {
156           __m512 vo = _mm512_load_act( out1+j );
157           vo = _mm512_add_ps( vo, _mm512_load_act( out2+j ) );
158 #ifdef USE_NTS_SPLIT
159           _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
160 #else
161           _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
162 #endif
163         }
164       } else if ( num_outp == 1 ) {
165         libxsmm_bfloat16* out1 = (libxsmm_bfloat16*)deloutp[0];
166 #ifdef _OPENMP
167 #pragma omp parallel for
168 #endif
169         for(int j=0; j<size; j+=16) {
170           __m512 vo = _mm512_load_act( out1+j );
171 #ifdef USE_NTS_SPLIT
172           _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
173 #else
174           _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
175 #endif
176         }
177       } else {
178 #ifdef _OPENMP
179 #pragma omp parallel for
180 #endif
181         for(int j=0; j<size; j+=16) {
182           __m512 vo = _mm512_load_act( &(((libxsmm_bfloat16*)deloutp[0])[j]) );
183           for(int i=1; i<num_outp; i++) {
184             vo = _mm512_add_ps( vo, _mm512_load_act( &(((libxsmm_bfloat16*)deloutp[i])[j]) ) );
185           }
186 #ifdef USE_NTS_SPLIT
187           _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
188 #else
189           _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
190 #endif
191         }
192       }
193     } else {
194 #if defined(_OPENMP)
195 #pragma omp parallel
196 #endif
197       {
198         union libxsmm_bfloat16_hp deloutput_32_0, deloutput_32_1;
199 
200         deloutput_32_0.i[0] = 0;
201         deloutput_32_0.i[1] = 0;
202         deloutput_32_1.i[0] = 0;
203         deloutput_32_1.i[1] = 0;
204 
205 #if defined(_OPENMP)
206 #pragma omp for
207 #endif
208         for(int j=0; j<size; j++) {
209           deloutput_32_0.i[1] = ((libxsmm_bfloat16*)deloutp[0])[j];
210           for(int i=1; i<num_outp; i++) {
211             deloutput_32_1.i[1] = ((libxsmm_bfloat16*)deloutp[i])[j];
212             deloutput_32_0.f += deloutput_32_1.f;
213           }
214           ((libxsmm_bfloat16*)delinp)[j] = deloutput_32_0.i[1];
215           deloutput_32_0.i[0] = 0;
216           deloutput_32_0.i[1] = 0;
217         }
218       }
219     }
220 #else
221 #if defined(_OPENMP)
222 #pragma omp parallel
223 #endif
224     {
225       union libxsmm_bfloat16_hp deloutput_32_0, deloutput_32_1;
226 
227       deloutput_32_0.i[0] = 0;
228       deloutput_32_0.i[1] = 0;
229       deloutput_32_1.i[0] = 0;
230       deloutput_32_1.i[1] = 0;
231 
232 #if defined(_OPENMP)
233 #pragma omp for
234 #endif
235       for(int j=0; j<size; j++) {
236         deloutput_32_0.i[1] = ((libxsmm_bfloat16*)deloutp[0])[j];
237         for(int i=1; i<num_outp; i++) {
238           deloutput_32_1.i[1] = ((libxsmm_bfloat16*)deloutp[i])[j];
239           deloutput_32_0.f += deloutput_32_1.f;
240         }
241         ((libxsmm_bfloat16*)delinp)[j] = deloutput_32_0.i[1];
242         deloutput_32_0.i[0] = 0;
243         deloutput_32_0.i[1] = 0;
244       }
245     }
246 #endif
247   }
248 
249   delinpb->setLayoutType(deloutpb[0]->getLayoutType());
250 }
251