/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/
11
12
#include <stdio.h>
#include <vector>
#include <omp.h>
#include <immintrin.h>
#include "SplitLoop.hpp"
17
/* Load 16 bf16 values and widen to fp32 by shifting each 16-bit value into
 * the upper half of a 32-bit lane (bf16 is the high 16 bits of an fp32). */
# define _mm512_load_act(A) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(_mm256_loadu_si256((__m256i*)(A))),16))
#if 1
/* Constants for fp32 -> bf16 round-to-nearest-even conversion.
 * Marked static: they are only used by the macros below within this
 * translation unit, and a non-static definition would clash at link time
 * if another layer file defines the same names. */
static __m512i vnaninf    = _mm512_set1_epi32( 0x7f800000 ); /* exponent mask: detects NaN/Inf */
static __m512i vrneadd    = _mm512_set1_epi32( 0x00007fff ); /* rounding bias for RNE */
static __m512i vfixup     = _mm512_set1_epi32( 0x00000001 ); /* tie-break increment */
static __m512i vfixupmask = _mm512_set1_epi32( 0x00010000 ); /* LSB of the retained bf16 mantissa */
/* Round fp32 lanes to bf16 (still in fp32 bit positions) with
 * round-to-nearest-even; NaN/Inf lanes are passed through unmodified. */
# define _mm512_roundbf16rne(A) _mm512_mask_add_epi32( _mm512_castps_si512( A ), _mm512_cmp_epi32_mask( _mm512_and_epi32( _mm512_castps_si512( A ), vnaninf ), vnaninf, _MM_CMPINT_NE ), _mm512_castps_si512( A ), _mm512_mask_add_epi32( vrneadd , _mm512_cmp_epi32_mask( _mm512_and_epi32( _mm512_castps_si512( A ), vfixupmask ), vfixupmask, _MM_CMPINT_EQ ), vrneadd, vfixup ) )
/* Store 16 fp32 lanes as bf16: round, shift down, narrow to 16-bit.
 * The stream variant uses a non-temporal store (bypasses the cache). */
# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)A,_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)A,_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_roundbf16rne((B)),16)))
#else
/* Cheaper truncating (round-toward-zero) fallback, currently disabled. */
# define _mm512_stream_act(A,B) _mm256_stream_si256((__m256i*)A,_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
# define _mm512_store_act(A,B) _mm256_storeu_si256((__m256i*)A,_mm512_cvtepi32_epi16(_mm512_srai_epi32(_mm512_castps_si512((B)),16)))
#endif

#define VLEN 16
33
forwardPropagate(TensorBuf * inpb,vector<TensorBuf * > & outpb,int tid)34 void SplitLoop::forwardPropagate(TensorBuf *inpb, vector<TensorBuf*>& outpb, int tid)
35 {
36 for(int i=0; i<outpb.size(); i++)
37 {
38 outpb[i]->setBuffer(inpb->getBuffer());
39 outpb[i]->setBufferSize(inpb->getBufferSize());
40 outpb[i]->setLayoutType(inpb->getLayoutType());
41 }
42 }
43
backPropagate(vector<TensorBuf * > & deloutpb,TensorBuf * delinpb,int tid)44 void SplitLoop::backPropagate(vector<TensorBuf *>& deloutpb, TensorBuf *delinpb, int tid)
45 {
46 assert(gp->bdims == gp->tdims);
47
48 int nImg = gp->batch_size;
49 int nIfm = gp->nInput;
50 int ifh = gp->iHeight;
51 int ifw = gp->iWidth;
52
53 int in_dtype = delinpb->getDataType();
54 int out_dtype = deloutpb[0]->getDataType();
55
56 void* delinp = delinpb->getBuffer();
57
58 void *deloutp[deloutpb.size()];
59 int num_outp = 1;
60 int size = nImg*nIfm*ifh*ifw;
61
62 deloutp[0] = deloutpb[0]->getBuffer();
63
64 for(int i=1; i<deloutpb.size(); i++)
65 {
66 if(deloutpb[i] == NULL) continue;
67
68 deloutp[num_outp] = deloutpb[i]->getBuffer();
69 num_outp++;
70 }
71
72 if(in_dtype == DT_FLOAT && out_dtype == DT_FLOAT)
73 {
74 #ifdef __AVX512F__
75 if (size % 16 == 0) {
76 if ( num_outp == 2 ) {
77 float* out1 = (float*)deloutp[0];
78 float* out2 = (float*)deloutp[1];
79 #ifdef _OPENMP
80 #pragma omp parallel for
81 #endif
82 for(int j=0; j<size; j+=16) {
83 __m512 vo = _mm512_loadu_ps( out1+j );
84 vo = _mm512_add_ps( vo, _mm512_loadu_ps( out2+j ) );
85 #ifdef USE_NTS_SPLIT
86 _mm512_stream_ps( &(((float*)delinp)[j]), vo );
87 #else
88 _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
89 #endif
90 }
91 } else if ( num_outp == 1 ) {
92 float* out1 = (float*)deloutp[0];
93 #ifdef _OPENMP
94 #pragma omp parallel for
95 #endif
96 for(int j=0; j<size; j+=16) {
97 __m512 vo = _mm512_loadu_ps( out1+j );
98 #ifdef USE_NTS_SPLIT
99 _mm512_stream_ps( &(((float*)delinp)[j]), vo );
100 #else
101 _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
102 #endif
103 }
104 } else {
105 #ifdef _OPENMP
106 #pragma omp parallel for
107 #endif
108 for(int j=0; j<size; j+=16) {
109 __m512 vo = _mm512_loadu_ps( &(((float*)deloutp[0])[j]) );
110 for(int i=1; i<num_outp; i++) {
111 vo = _mm512_add_ps( vo, _mm512_loadu_ps( &(((float*)deloutp[i])[j]) ) );
112 }
113 #ifdef USE_NTS_SPLIT
114 _mm512_stream_ps( &(((float*)delinp)[j]), vo );
115 #else
116 _mm512_storeu_ps( &(((float*)delinp)[j]), vo );
117 #endif
118 }
119 }
120 } else {
121 #ifdef _OPENMP
122 #pragma omp parallel for
123 #endif
124 for(int j=0; j<size; j++) {
125 float o = ((float*)deloutp[0])[j];
126 for(int i=1; i<num_outp; i++) {
127 o += ((float*)deloutp[i])[j];
128 }
129 ((float*)delinp)[j] = o;
130 }
131 }
132 #else
133 #ifdef _OPENMP
134 #pragma omp parallel for
135 #endif
136 for(int j=0; j<size; j++) {
137 float o = ((float*)deloutp[0])[j];
138 for(int i=1; i<num_outp; i++) {
139 o += ((float*)deloutp[i])[j];
140 }
141 delinp[j] = o;
142 }
143 #endif
144 }
145 else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
146 {
147 #ifdef __AVX512F__
148 if (size % 16 == 0) {
149 if ( num_outp == 2 ) {
150 libxsmm_bfloat16* out1 = (libxsmm_bfloat16*)deloutp[0];
151 libxsmm_bfloat16* out2 = (libxsmm_bfloat16*)deloutp[1];
152 #ifdef _OPENMP
153 #pragma omp parallel for
154 #endif
155 for(int j=0; j<size; j+=16) {
156 __m512 vo = _mm512_load_act( out1+j );
157 vo = _mm512_add_ps( vo, _mm512_load_act( out2+j ) );
158 #ifdef USE_NTS_SPLIT
159 _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
160 #else
161 _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
162 #endif
163 }
164 } else if ( num_outp == 1 ) {
165 libxsmm_bfloat16* out1 = (libxsmm_bfloat16*)deloutp[0];
166 #ifdef _OPENMP
167 #pragma omp parallel for
168 #endif
169 for(int j=0; j<size; j+=16) {
170 __m512 vo = _mm512_load_act( out1+j );
171 #ifdef USE_NTS_SPLIT
172 _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
173 #else
174 _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
175 #endif
176 }
177 } else {
178 #ifdef _OPENMP
179 #pragma omp parallel for
180 #endif
181 for(int j=0; j<size; j+=16) {
182 __m512 vo = _mm512_load_act( &(((libxsmm_bfloat16*)deloutp[0])[j]) );
183 for(int i=1; i<num_outp; i++) {
184 vo = _mm512_add_ps( vo, _mm512_load_act( &(((libxsmm_bfloat16*)deloutp[i])[j]) ) );
185 }
186 #ifdef USE_NTS_SPLIT
187 _mm512_stream_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
188 #else
189 _mm512_store_act( &(((libxsmm_bfloat16*)delinp)[j]), vo );
190 #endif
191 }
192 }
193 } else {
194 #if defined(_OPENMP)
195 #pragma omp parallel
196 #endif
197 {
198 union libxsmm_bfloat16_hp deloutput_32_0, deloutput_32_1;
199
200 deloutput_32_0.i[0] = 0;
201 deloutput_32_0.i[1] = 0;
202 deloutput_32_1.i[0] = 0;
203 deloutput_32_1.i[1] = 0;
204
205 #if defined(_OPENMP)
206 #pragma omp for
207 #endif
208 for(int j=0; j<size; j++) {
209 deloutput_32_0.i[1] = ((libxsmm_bfloat16*)deloutp[0])[j];
210 for(int i=1; i<num_outp; i++) {
211 deloutput_32_1.i[1] = ((libxsmm_bfloat16*)deloutp[i])[j];
212 deloutput_32_0.f += deloutput_32_1.f;
213 }
214 ((libxsmm_bfloat16*)delinp)[j] = deloutput_32_0.i[1];
215 deloutput_32_0.i[0] = 0;
216 deloutput_32_0.i[1] = 0;
217 }
218 }
219 }
220 #else
221 #if defined(_OPENMP)
222 #pragma omp parallel
223 #endif
224 {
225 union libxsmm_bfloat16_hp deloutput_32_0, deloutput_32_1;
226
227 deloutput_32_0.i[0] = 0;
228 deloutput_32_0.i[1] = 0;
229 deloutput_32_1.i[0] = 0;
230 deloutput_32_1.i[1] = 0;
231
232 #if defined(_OPENMP)
233 #pragma omp for
234 #endif
235 for(int j=0; j<size; j++) {
236 deloutput_32_0.i[1] = ((libxsmm_bfloat16*)deloutp[0])[j];
237 for(int i=1; i<num_outp; i++) {
238 deloutput_32_1.i[1] = ((libxsmm_bfloat16*)deloutp[i])[j];
239 deloutput_32_0.f += deloutput_32_1.f;
240 }
241 ((libxsmm_bfloat16*)delinp)[j] = deloutput_32_0.i[1];
242 deloutput_32_0.i[0] = 0;
243 deloutput_32_0.i[1] = 0;
244 }
245 }
246 #endif
247 }
248
249 delinpb->setLayoutType(deloutpb[0]->getLayoutType());
250 }
251