/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                      *
* Further information: https://github.com/hfp/libxsmm/                       *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar, Alexander Heinecke (Intel Corp.)
******************************************************************************/

#include <string>
#include "FusedConvBN.hpp"
#include "fillers.hpp"

#ifdef USE_MLSL
#include "mpi.h"
#endif


using namespace std;
using namespace gxm;

FusedConvBNNode::FusedConvBNNode(FusedConvBNParams* p, MLEngine* e) : NNNode(p, e)
{
  nname_ = p->get_node_name();
  ntype_ = p->get_node_type();
  bottom_ = p->get_bottom_names();
  top_ = p->get_top_names();
  bp_flag_ = p->get_bprop_flag();
  has_weights_ = true;
  bot_compute_engine_ = p->get_compute_engine();

  tenTop_ = new Tensor(top_[0]);
  assert(tenTop_ != NULL);
  tenTop_->setOwner(this);
  tenTop_->setType(ACT);
  tenTopData_ = tenTop_->getBuf(DATA);
  tenTopData_->setBufferType(DATA);

  tenMid_ = new Tensor("mid_"+top_[0]);
  assert(tenMid_ != NULL);
  tenMid_->setOwner(this);
  tenMid_->setType(ACT);
  tenMidData_ = tenMid_->getBuf(DATA);
  tenMidData_->setBufferType(DATA);

  tenBot_.resize(bottom_.size());
  tenBotData_.resize(bottom_.size());

  for(int i=0; i < bottom_.size(); i++)
  {
#ifndef NDEBUG
    printf("bottom%d name %s\n",i,bottom_[i].c_str());
#endif

    if(bottom_[i] == "data")
      tenBot_[i] = e->get_tensor(bottom_[i], INPUT);
    else
      tenBot_[i] = e->get_tensor(bottom_[i], ACT);

    assert(tenBot_[i] != NULL);
    NNNode *pnn = (NNNode*)tenBot_[i]->getOwner();
    setPrevNode(pnn);
    mode_ = pnn->getMode();
    pnn->set_top_compute_engine(p->get_compute_engine());
    bot_cengine_ = pnn->get_bot_compute_engine();

    tenBotData_[i] = tenBot_[i]->getBuf(DATA);
  }

  in_dtype = tenBotData_[0]->getDataType();
  out_dtype = p->get_data_type();
  tenTopData_->setDataType(out_dtype);

  // Get input tensor shape (bottom)
  Shape* bs = tenBot_[0]->getShape();
  assert(bs->ndims <= MAX_DIMS);

  // Create shape of output tensor (top)
  vector<int> vd = p->get_kernel_dims();
  vector<int> mvp = p->get_mid_pads();
  vector<int> ovp = p->get_top_pads();
  vector<int> ivp = p->get_bot_pads();
  vector<int> vcs = p->get_c_strides();
  vector<int> vbns = p->get_bn_strides();

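  // Mid ("conv output") spatial dims follow the standard convolution
  // formula H_out = (H_in - K + 2*P)/S + 1. For example, a 224x224 input
  // with a 7x7 kernel, pad 3 and stride 2 gives (224 - 7 + 2*3)/2 + 1 = 112.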
  shape_setzero(&ms_);
  ms_.ndims = bs->ndims; // Number of dimensions
  ms_.dims[0] = bs->dims[0]; // Minibatch size
  ms_.dims[1] = p->get_output(); // Num output feature maps
  ms_.dims[2] = (bs->dims[2] - vd[0] + 2*ivp[0])/vcs[0] + 1; // Height
  ms_.dims[3] = (bs->dims[3] - vd[1] + 2*ivp[1])/vcs[1] + 1; // Width

  tenMid_->setShape(&ms_);

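  // The fused BN stage may itself downsample: top dims are the mid dims
  // divided by the BN strides, so bn_stride 2 halves height and width
  // (e.g. 112x112 -> 56x56).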
  shape_setzero(&ts_);
  ts_.ndims = bs->ndims; // Number of dimensions
  ts_.dims[0] = bs->dims[0]; // Minibatch size
  ts_.dims[1] = p->get_output(); // Num output feature maps
  ts_.dims[2] = ms_.dims[2]/vbns[0]; // Height
  ts_.dims[3] = ms_.dims[3]/vbns[1]; // Width

  tenTop_->setShape(&ts_);

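  // One allocation backs both stages: it is sized for the padded conv
  // output (convelem) plus the padded BN output (bnelem). In
  // forwardPropagate() the front of the buffer becomes the mid tensor and
  // the top tensor pointer is advanced past it.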
  long long int tsize;
  int convelem = ms_.dims[0] * ms_.dims[1] * (ms_.dims[2] + 2*mvp[0]) * (ms_.dims[3] + 2*mvp[1]);
  int bnelem = ts_.dims[0] * ts_.dims[1] * (ts_.dims[2] + 2*ovp[0]) * (ts_.dims[3] + 2*ovp[1]);
  int telem = convelem + bnelem;

  if(out_dtype == DT_FLOAT)
    tsize = telem*sizeof(float);
  else if(out_dtype == DT_BF16)
    tsize = telem*sizeof(libxsmm_bfloat16);

  tenTopData_->setBufferSize(tsize);

  // Create FP weight tensor
  weight_ = top_[0] + "_wt";
  tenWeight_ = new Tensor(weight_);
  assert(tenWeight_ != NULL);
  tenWeight_->setOwner(this);
  tenWeight_->setType(CONVWEIGHT);

  shape_setzero(&ws_);

  ws_.ndims = ts_.ndims;      // Number of dimensions
  ws_.dims[0] = ms_.dims[1];  // Num output feature maps (from mid tensor)
  ws_.dims[1] = bs->dims[1];  // Num input feature maps (from bottom tensor)
  ws_.dims[2] = vd[0];        // Kernel height
  ws_.dims[3] = vd[1];        // Kernel width

  tenWeight_->setShape(&ws_);
  tenWeight_->setBufDataType(DATA, DT_FLOAT);
  tenWeightData_ = tenWeight_->getBuf(DATA);
  tenWeightData_->setBufferType(DATA);

  int welem = 1;
  long long int wsize;
  for(int i=0; i<ws_.ndims; i++)
    welem = welem*ws_.dims[i];

  // size of master weights -- FP32.
  wsize = welem*sizeof(float);

  gparams_.num_numa_nodes = NUM_NUMA_NODES;
  tenWeightData_->setBufferSize(wsize);

  wfiller_type_ = p->get_weight_filler_type();
  variance_norm_ = p->get_variance_norm();
  std_ = p->get_std();

  lr_mult_ = p->get_lr_mult();
  decay_mult_ = p->get_decay_mult();

  Shape sss;
  shape_setzero(&sss);
  sss.ndims = 1;
  sss.dims[0] = ts_.dims[1];

  scale_ = top_[0] + "_scale";
  tenScale_ = new Tensor(scale_);
  assert(tenScale_ != NULL);
  tenScale_->setOwner(this);
  tenScale_->setType(BNORMSCALE);
  tenScale_->setShape(&sss);
  tenScaleData_ = tenScale_->getBuf(DATA);
  tenScaleData_->setDataType(DT_FLOAT);
  tenScaleData_->setBufferType(DATA);

  telem = sss.dims[0];
  tsize = telem*sizeof(float);
  tenScaleData_->setBufferSize(tsize);

  shift_ = top_[0] + "_shift";
  tenShift_ = new Tensor(shift_);
  assert(tenShift_ != NULL);
  tenShift_->setOwner(this);
  tenShift_->setType(BNORMSHIFT);
  tenShift_->setShape(&sss);
  tenShiftData_ = tenShift_->getBuf(DATA);
  tenShiftData_->setDataType(DT_FLOAT);
  tenShiftData_->setBufferType(DATA);

  tenShiftData_->setBufferSize(tsize);

  mean_ = top_[0] + "_mean";
  tenMean_ = new Tensor(mean_);
  assert(tenMean_ != NULL);
  tenMean_->setOwner(this);
  tenMean_->setType(BNORMMEAN);
  tenMean_->setShape(&sss);
  tenMeanData_ = tenMean_->getBuf(DATA);
  tenMeanData_->setDataType(DT_FLOAT);
  tenMeanData_->setBufferType(DATA);
  tenMeanData_->setBufferSize(tsize);

  var_ = top_[0] + "_var";
  tenVar_ = new Tensor(var_);
  assert(tenVar_ != NULL);
  tenVar_->setOwner(this);
  tenVar_->setType(BNORMVAR);
  tenVar_->setShape(&sss);
  tenVarData_ = tenVar_->getBuf(DATA);
  tenVarData_->setDataType(DT_FLOAT);
  tenVarData_->setBufferType(DATA);
  tenVarData_->setBufferSize(tsize);

  if(!e->is_inference_only()) {
    if(bp_flag_)
    {
      tenBotDiff_.resize(bottom_.size());
      for(int i=0; i<bottom_.size(); i++)
      {
        tenBotDiff_[i] = tenBot_[i]->addBuf(); // DIFF type and index
        tenBotDiff_[i]->setDataType(in_dtype);
        tenBotDiff_[i]->setBufferType(DIFF);

        // Set the size of the input-gradient buffer
        Shape *bs = tenBot_[i]->getShape();
        int botelem = bs->dims[0] * bs->dims[1] * (bs->dims[2] + 2*ivp[0]) * (bs->dims[3] + 2*ivp[1]);
        if(in_dtype == DT_FLOAT)
          tenBotDiff_[i]->setBufferSize((botelem + convelem)*sizeof(float));
        else if(in_dtype == DT_BF16)
          tenBotDiff_[i]->setBufferSize((botelem + convelem)*sizeof(libxsmm_bfloat16));
      }
      tenMidDiff_ = tenMid_->addBuf(); // DIFF type and index
      tenMidDiff_->setDataType(in_dtype);
      tenMidDiff_->setBufferType(DIFF);
    }

    if(has_weights_)
    {
      if(tenMidDiff_ == NULL)
      {
        tenMidDiff_ = tenMid_->addBuf(); // DIFF type and index
        tenMidDiff_->setDataType(in_dtype);
        tenMidDiff_->setBufferType(DIFF);
        if(in_dtype == DT_FLOAT)
          tenMidDiff_->setBufferSize(convelem*sizeof(float));
        else if(in_dtype == DT_BF16)
          tenMidDiff_->setBufferSize(convelem*sizeof(libxsmm_bfloat16));
      }

      tenWeightDiff_ = tenWeight_->addBuf(); // DIFF type and index
      tenWeightDiff_->setBufferType(DIFF);

      tenWeightInc_ = tenWeight_->addBuf(); // HISTORY type and index
      tenWeightInc_->setDataType(DT_FLOAT);
      tenWeightInc_->setBufferType(HISTORY);
      tenWeightInc_->setBufferSize(welem*sizeof(float));

      // Set the size of the weight-gradient buffer and the weight-increment buffer
      if(in_dtype == DT_FLOAT)
      {
        tenWeightDiff_->setDataType(DT_FLOAT);
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
      }
      else if(in_dtype == DT_BF16)
      {
        tenWeightDiff_->setDataType(DT_BF16);
        tenWeightDiff_->setBufferSize(welem*sizeof(libxsmm_bfloat16));
      }

      tenScaleDiff_ = tenScale_->addBuf();
      tenScaleDiff_->setDataType(DT_FLOAT);
      tenScaleDiff_->setBufferType(DIFF);
      tenScaleDiff_->setBufferSize(tsize);

      tenScaleInc_ = tenScale_->addBuf();
      tenScaleInc_->setDataType(DT_FLOAT);
      tenScaleInc_->setBufferType(HISTORY);
      tenScaleInc_->setBufferSize(tsize);

      tenShiftDiff_ = tenShift_->addBuf();
      tenShiftDiff_->setDataType(DT_FLOAT);
      tenShiftDiff_->setBufferType(DIFF);
      tenShiftDiff_->setBufferSize(tsize);

      tenShiftInc_ = tenShift_->addBuf();
      tenShiftInc_->setDataType(DT_FLOAT);
      tenShiftInc_->setBufferType(HISTORY);
      tenShiftInc_->setBufferSize(tsize);
    }
  }
  else {
    tenMidDiff_ = NULL;
    tenWeightDiff_ = NULL;
    tenWeightInc_ = NULL;
    tenScaleDiff_ = NULL;
    tenShiftDiff_ = NULL;
    tenScaleInc_ = NULL;
    tenShiftInc_ = NULL;
  }

  // Register output tensor in tensor map
  bool inserted = e->register_tensor(top_[0], ACT, tenTop_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",top_[0].c_str());

  string m = "mid_"+top_[0];
  inserted = e->register_tensor(m, ACT, tenMid_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",m.c_str());

  // Register weight tensor in weight tensor map
  inserted = e->register_tensor(weight_, CONVWEIGHT, tenWeight_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",weight_.c_str());

  inserted = e->register_tensor(scale_, BNORMSCALE, tenScale_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",scale_.c_str());

  inserted = e->register_tensor(shift_, BNORMSHIFT, tenShift_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",shift_.c_str());

  inserted = e->register_tensor(mean_, BNORMMEAN, tenMean_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",mean_.c_str());

  inserted = e->register_tensor(var_, BNORMVAR, tenVar_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",var_.c_str());

  // Setup parameter structure for convolution computation in library
  gparams_.bdims = bs->ndims;
  gparams_.tdims = ts_.ndims;
  gparams_.mdims = ms_.ndims;
  gparams_.wdims = ws_.ndims;

  gparams_.node_name = nname_;
  gparams_.node_type = ntype_;
  gparams_.nInput.resize(bottom_.size());
  gparams_.nInput[0] = bs->dims[1];
  if(bottom_.size() > 1)
    gparams_.nInput[1] = tenBot_[1]->getShape()->dims[1];
  gparams_.nOutput = ts_.dims[1];
  gparams_.batch_size = bs->dims[0];
  gparams_.iHeight = bs->dims[2];
  gparams_.iWidth = bs->dims[3];
  gparams_.mHeight = ms_.dims[2];
  gparams_.mWidth = ms_.dims[3];
  gparams_.oHeight = ts_.dims[2];
  gparams_.oWidth = ts_.dims[3];
  gparams_.ipad_h = ivp[0];
  gparams_.ipad_w = ivp[1];
  gparams_.mpad_h = mvp[0];
  gparams_.mpad_w = mvp[1];
  gparams_.opad_h = ovp[0];
  gparams_.opad_w = ovp[1];
  gparams_.physical_padding = p->get_physical_padding();

  gparams_.group = p->get_group();
  gparams_.c_stride_h = vcs[0];
  gparams_.c_stride_w = vcs[1];
  gparams_.bn_stride_h = vbns[0];
  gparams_.bn_stride_w = vbns[1];
  gparams_.kh = ws_.dims[2];
  gparams_.kw = ws_.dims[3];

  gparams_.relu_fwd = p->get_relu_fwd();
  gparams_.relu_bwd = p->get_relu_bwd();

  gparams_.mmf = p->get_mmf();
  gparams_.eps = p->get_eps();
  gparams_.use_global_stats = p->get_global_stats_flag();
  gparams_.eltwise = p->get_eltwise();
  gparams_.bprop = bp_flag_;

  gparams_.in_data_type = in_dtype;
  gparams_.out_data_type = out_dtype;
  gparams_.algType = p->get_algo_type();
  gparams_.num_threads = e->get_num_threads();

  // get solver
  solver_ = e->getSolver();

  //get global scratch tensor buffer
  tenScratchData_ = e->getScratchBuffer();

  // get engine
  eptr_ = e;

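  // Five gradient parameter sets are registered, in the order the update
  // code indexes them later: 0 = conv weights, 1 = BN scale (gamma),
  // 2 = BN shift (beta), 3 = batch mean, 4 = batch variance.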
#ifdef USE_MLSL
  MLSL::DataType dt = MLSL::DT_FLOAT;
  MLSL::OperationRegInfo *myRegInfo;
  MLSL::Session *s = eptr_->get_session();
  myRegInfo = s->CreateOperationRegInfo(MLSL::OT_CC);
  myRegInfo->SetName(nname_.c_str());
  myRegInfo->AddParameterSet(gparams_.nInput[0]*gparams_.nOutput/gparams_.group, gparams_.kw*gparams_.kh, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);

  myRegInfo->Validate();
  size_t opIdx = s->AddOperation(myRegInfo, e->get_distribution());
  this->op_ = s->GetOperation(opIdx);
  s->DeleteOperationRegInfo(myRegInfo);
  e->get_combo_grad_comms_vec().push_back(op_);
#endif

  configure(p->get_compute_engine());
}

void FusedConvBNNode::configure(int engine)
{
  switch(engine)
  {
    case XSMM:
      impl = new FusedConvBNXSMM(&gparams_, engine);
      break;
  }
}

void FusedConvBNNode::fillWeightBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  int dtype = DT_FLOAT;
  void *ptr = tBuf->getBuffer();

#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif

  int ic = gparams_.nInput[0];
  int oc = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;
  int g = gparams_.group;
  int fanin = (ic * kh * kw)/g;
  int fanout = (oc * kh * kw)/g;
  int welem = ic * oc * kh * kw;

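  // Only the root process runs the random filler; under MLSL the result is
  // then broadcast over MPI so every rank starts from identical weights.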
  if(buftype == DATA)
  {
    if(node_id == 0)
      initBuffer(ptr, variance_norm_, fanin, fanout, welem*sizeof(float), wfiller_type_, std_);

#ifdef USE_MLSL
    if(dtype == DT_FLOAT)
      MPI_Bcast(ptr, welem, MPI_FLOAT, 0, MPI_COMM_WORLD);
#endif
  }
  else if(buftype == HISTORY || buftype == DIFF)
    memset(ptr, 0, size);
}

void FusedConvBNNode::fillWeightMultipliers(float* lr, float* decay, long long int size)
{
  for(int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[0];
    decay[i] = decay_mult_[0];
  }
}

void FusedConvBNNode::fillBiasMultipliers(float* lr, float* decay, long long int size)
{
  for(int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[1];
    decay[i] = decay_mult_[1];
  }
}

void FusedConvBNNode::fillBuffer(TensorBuf* tBuf, int buftype, long long int size)
{
  int ttype = tBuf->getTensor()->getType();
  int dtype = DT_FLOAT;
  void *ptr = tBuf->getBuffer();

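  // BN scale (gamma) starts at 1, except for layers whose name contains
  // "bn3", which start at 0 (zero-gamma init, a common trick that makes a
  // residual branch initially act as identity); everything else starts at 0.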
  if(ttype==BNORMSCALE && buftype == DATA)
  {
    if(nname_.find("bn3") == nname_.npos)
      initConstantBuffer(ptr, size, "CONSTANT", 1.0f);
    else
      initConstantBuffer(ptr, size, "CONSTANT", 0.0f);
  }
  else
    initConstantBuffer(ptr, size, "CONSTANT", 0.0f);
}

void FusedConvBNNode::Checkpoint(TensorBuf *tBuf, string name, string format)
{
  long long int bytes = tBuf->getBufferSize();
  int dtype = tBuf->getDataType();

  FILE* f;
  void* ptr;
  size_t pos;

  if((name.find("30") == name.npos) && (name.find("60") == name.npos) && (name.find("80") == name.npos))
    while((pos = name.find("/", 10)) != name.npos)
      name.replace(pos, 1, 1, '_');

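  // Cheap sanity check before writing: scan the first 16 values and skip
  // the checkpoint entirely if any of them is NaN or Inf.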
  float* p = (float*)tBuf->getBuffer();
  bool no_checkpt = false;
  for(int i=0; i<16; i++)
  {
    if(isnan(p[i]) || isinf(p[i]))
    {
      no_checkpt = true;
      printf("Warning! %s Did not checkpoint! Weights are NaNs or Inf\n", nname_.c_str());
      break;
    }
  }

  if(!no_checkpt)
  {
    if(format.compare("binary") == 0)
    {
      f = fopen(name.c_str(), "wb");
      if(f != NULL)
      {
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else if(name.find("mean") != name.npos || name.find("var") != name.npos)
          ptr = tBuf->getPrivBuffer();
        else
          ptr = tBuf->getBuffer();

        size_t b = fwrite(ptr, 1, bytes, f);
        assert((long long int)b == bytes);

        if(name.find("wt") != name.npos)
          _mm_free(ptr);
      }
      else
        printf("Warning: could not checkpoint to file %s\n",name.c_str());
    }
    else
    {
      f = fopen(name.c_str(), "w");
      if(f != NULL)
      {
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
          ptr = tBuf->getBuffer();

        for(int i=0; i<bytes/sizeof(float); i++)
          fprintf(f, "%f\n", *((float*)ptr + i));

        if(name.find("wt") != name.npos)
          _mm_free(ptr);
      }
      else
        printf("Warning: could not checkpoint to file %s\n",name.c_str());
    }
    if(f != NULL)
    {
      fflush(f);
      fclose(f);
    }
  }
}

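// Vectorized FP32 -> BF16 conversion: the "rne_adjustment" step rounds to
// nearest-even before "truncate" keeps the high 16 bits. Each AVX-512
// iteration handles 16 floats; note there is no scalar remainder loop, so
// len is assumed to be a multiple of 16.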
void FusedConvBNNode::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m512  vfp32  = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(in + i));
    __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32);
    _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
  }
}

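// The inverse widening conversion is lossless: a bfloat16 is the top 16
// bits of an FP32 value, so each element is expanded into the high half of
// a 32-bit lane.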
void FusedConvBNNode::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m256i vbfp16    = _mm256_loadu_si256( (const __m256i*)(in+i) );
    __m512  vfp32     = gxm_bfp16_to_fp32_avx512f( vbfp16 );
    _mm512_storeu_ps( out+i, vfp32 );
  }
}

void FusedConvBNNode::forwardPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ifm1 = gparams_.eltwise ? gparams_.nInput[1] : 0;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int oph = gparams_.opad_h;
  int opw = gparams_.opad_w;
  int ofhp = ofh + 2*oph;
  int ofwp = ofw + 2*opw;
  int bnsh = gparams_.bn_stride_h;
  int bnsw = gparams_.bn_stride_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifndef NDEBUG
  // printf("Executing FP %s: input %p, weights %p, output %p\n",NNNode::nname_.c_str(), bot, wt, top);
  printf("Executing FP %s\n",NNNode::nname_.c_str());
  printf("Inputs: %d x %d x %d\n",ifm0, ifh, ifw);
  printf("Outputs: %d x %d x %d\n",ofm, ofh, ofw);
  printf("Weights: %d x %d x %d x %d\n", ifm0, ofm, kh, kw);
  printf("Bias: %d\n", ofm);
#endif

  if(first_fp)
  {
    impl->set_top_compute_engine(top_compute_engine_);
    impl->set_bot_compute_engine(bot_cengine_);
    impl->set_node_name(nname_);
    impl->set_scratch_buffer(tenScratchData_);

    if(eptr_->get_execution_mode() == TRAIN || eptr_->get_execution_mode() == VAL)
    {
      impl->set_global_stats(false);
      gparams_.exec_mode = "TRAIN";
    }
    else if(eptr_->get_execution_mode() == TEST)
      impl->set_global_stats(true);

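    // Carve the single top allocation in two: the first nImg*ofm*mfhp*mfwp
    // elements back the conv output (mid tensor), and the top (BN output)
    // pointer is advanced past that region.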
    tenMidData_->setBuffer(tenTopData_->getBuffer());

    if(out_dtype == DT_FLOAT)
    {
      float* ptr = (float*)tenMidData_->getBuffer();
      int size = nImg * ofm * mfhp * mfwp;
      tenMidData_->setBufferSize(size*sizeof(float));
      tenTopData_->setBuffer(tenTopData_->getBuffer() + size*sizeof(float));
      tenTopData_->setBufferSize(tenTopData_->getBufferSize() - size*sizeof(float));

      // NUMA initialize Conv output
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;

      // NUMA initialize BN output
      size = nImg * ofm * (ofh/bnsh + 2*oph) * (ofw/bnsw + 2*opw);
      ptr = (float*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
      int size = nImg * ofm * mfhp * mfwp;
      tenMidData_->setBufferSize(size*sizeof(libxsmm_bfloat16));
      tenTopData_->setBuffer(tenTopData_->getBuffer() + size*sizeof(libxsmm_bfloat16));
      tenTopData_->setBufferSize(tenTopData_->getBufferSize() - size*sizeof(libxsmm_bfloat16));

      // NUMA initialize Conv output
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;

      // NUMA initialize BN output
      ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
      size = nImg * ofm * (ofh/bnsh + 2*oph) * (ofw/bnsw + 2*opw);

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

    cbptr = (float*)_mm_malloc(10240*sizeof(float), 64);
    scf_ = eptr_->get_scaling_factor();
    impl->set_scaling_factor(scf_);

    first_fp = false;
  }

  impl->forwardPropagate(tenBotData_, tenWeightData_, tenWeightInc_, tenMidData_, tenScaleData_, tenShiftData_, tenMeanData_, tenVarData_, tenTopData_, 0);

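  // Advance the moving-average scaling factor during training: scf
  // accumulates the geometric series 1 + mmf + mmf^2 + ..., presumably used
  // by the BN implementation to normalize the running mean/variance.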
  if(eptr_->get_execution_mode() != TEST && eptr_->get_execution_mode() != VAL)
  {
    scf_ *= gparams_.mmf;
    scf_ += 1.;

    eptr_->set_scaling_factor(scf_);
  }

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenTopData_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenMidData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP mid activations are NaN or Inf\n", nname_.c_str());
        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
        printf("cbptr[%d] = %d, cbptr[%d] = %f\n",i,ptr[i],i,cbptr[i]);
        exit(-1);
      }
    }
    convert_bf16_f32((libxsmm_bfloat16*)tenTopData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
        printf("cbptr[%d] = %d, cbptr[%d] = %f\n",i,ptr[i],i,cbptr[i]);
        exit(-1);
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenBotData_[0]->getBuffer();
      string s = nname_ + "_r_Inp";
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm0*ifhp*ifwp);

      if(gparams_.nInput.size() > 1)
      {
        ptr = (float*)tenBotData_[1]->getBuffer();
        s = nname_ + "_l_Inp";
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm1*ifhp*ifwp);
      }

      ptr = (float*)tenMidData_->getBuffer();
      s = nname_ + "_mid";
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*mfhp*mfwp);
    }
    else if(in_dtype == DT_BF16)
    {
      if(stptr == NULL)
      {
        int s = nImg*ofm*ofhp*ofwp;
        int ms = nImg*ofm*mfhp*mfwp;
        int is = nImg*ifm0*ifhp*ifwp;
        int is1 = 0;
        if(gparams_.nInput.size() > 1)
          is1 = nImg*ifm1*ifhp*ifwp;

        int size = s > ms ? s : ms;
        size = size > is ? size : is;
        size = size > is1 ? size : is1;

        stptr = (float*)libxsmm_aligned_malloc(size*sizeof(float), 2097152);
      }

      libxsmm_bfloat16 *ptr;
      if(tenBotData_[0]->getLPBuffer() != NULL)
        ptr = (libxsmm_bfloat16*)tenBotData_[0]->getLPBuffer();
      else
        ptr = (libxsmm_bfloat16*)tenBotData_[0]->getBuffer();

      string s = nname_ + "_r_Inp";
      convert_bf16_f32(ptr, stptr, nImg*ifm0*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm0*ifhp*ifwp);

      if(gparams_.nInput.size() > 1)
      {
        if(tenBotData_[1]->getLPBuffer() != NULL)
          ptr = (libxsmm_bfloat16*)tenBotData_[1]->getLPBuffer();
        else
          ptr = (libxsmm_bfloat16*)tenBotData_[1]->getBuffer();

        convert_bf16_f32(ptr, stptr, nImg*ifm1*ifhp*ifwp);
        s = nname_ + "_l_Inp";
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm1*ifhp*ifwp);
      }

      ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
      convert_bf16_f32(ptr, stptr, nImg*ofm*mfhp*mfwp);
      s = nname_ + "_mid";
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*mfhp*mfwp);
    }

    string s = nname_ + "_wt";
    float* wt = (float*)tenWeightData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), wt, ifm0*ofm*kh*kw);

    s = nname_ + "_gammap";
    float* gamma = (float*)tenScaleData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), gamma, gparams_.nOutput);

    s = nname_ + "_betap";
    float* beta = (float*)tenShiftData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), beta, gparams_.nOutput);

    if(out_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenTopData_->getBuffer();
      string s = nname_ + "_Outp";
      int size = nImg*ofm*(ofh/bnsh + 2*oph)*(ofw/bnsw + 2*opw);
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
      s = nname_ + "_Outp";
      int size = nImg*ofm*(ofh/bnsh + 2*oph)*(ofw/bnsw + 2*opw);
      convert_bf16_f32(ptr, stptr, size);
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }
  }
#endif
}

void FusedConvBNNode::backPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ifm1 = gparams_.eltwise ? gparams_.nInput[1] : 0;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  printf("Executing BP %s\n",NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Grad Inputs: %d x %d x %d\n", ifm0, ifh, ifw);
  printf("Weights: %d x %d x %d x %d\n", ofm, ifm0, kh, kw);
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  if(first_bp)
  {
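    // Mirror of the forward-side carving: the bottom-diff allocation holds
    // the input gradient in its first bsize0 elements, with the conv
    // mid-gradient (BN input gradient) placed right behind it.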
    int bsize0 = nImg*ifm0*ifhp*ifwp;
    int bsize1 = nImg*ifm1*ifhp*ifwp;
    int msize = nImg*ofm*mfhp*mfwp;

    if(in_dtype == DT_FLOAT)
    {
      float* ptr = (float*)tenBotDiff_[0]->getBuffer();
      tenMidDiff_->setBuffer(tenBotDiff_[0]->getBuffer() + bsize0*sizeof(float));
      tenMidDiff_->setBufferSize(msize*sizeof(float));
      tenBotDiff_[0]->setBufferSize(bsize0*sizeof(float));
      if(gparams_.eltwise)
        tenBotDiff_[1]->setBufferSize(bsize1*sizeof(float));

      // NUMA initialize Conv delinp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<bsize0; i++)
        ptr[i] = 0;

      // NUMA initialize BN delinp = Conv delmidp
      ptr = (float*)tenMidDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;

      ptr = gparams_.eltwise ? (float*)tenBotDiff_[1]->getBuffer() : NULL;
      if(ptr)
      {
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for(int i=0; i<bsize1; i++)
          ptr[i] = 0;
      }
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenBotDiff_[0]->getBuffer();
      tenMidDiff_->setBuffer(tenBotDiff_[0]->getBuffer() + bsize0*sizeof(libxsmm_bfloat16));
      tenMidDiff_->setBufferSize(msize*sizeof(libxsmm_bfloat16));
      tenBotDiff_[0]->setBufferSize(bsize0*sizeof(libxsmm_bfloat16));
      if(gparams_.eltwise)
        tenBotDiff_[1]->setBufferSize(bsize1*sizeof(libxsmm_bfloat16));

      // NUMA initialize Conv delinp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<bsize0; i++)
        ptr[i] = 0;

      // NUMA initialize BN delinp = Conv delmidp
      ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;

      ptr = gparams_.eltwise ? (libxsmm_bfloat16*)tenBotDiff_[1]->getBuffer() : NULL;
      if(ptr)
      {
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for(int i=0; i<bsize1; i++)
          ptr[i] = 0;
      }
    }
    first_bp = false;
  }

  impl->backPropagate(tenTopDiff_, tenWeightData_, tenScaleDiff_, tenShiftDiff_, tenMidDiff_, tenBotDiff_, 0);

#ifdef CHECK_BLOWUP_FP32
  // When gradients are BF16, widen to FP32 first instead of misreading the
  // buffer as floats (reuses the 10240-element cbptr scratch from forwardPropagate)
  if(out_dtype == DT_BF16)
    convert_bf16_f32((libxsmm_bfloat16*)tenTopDiff_->getBuffer(), cbptr, 10240);
  float* chkptr = (out_dtype == DT_BF16) ? cbptr : (float*)tenTopDiff_->getBuffer();
  for(int i=0; i<10240; i++)
  {
    if(isnan(chkptr[i]) || isinf(chkptr[i]))
    {
      printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id_ = 0;
#endif
  if(node_id_ == 0)
  {
    int sh = gparams_.bn_stride_h;
    int sw = gparams_.bn_stride_w;
    int ph = gparams_.opad_h;
    int pw = gparams_.opad_w;

    if(out_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenTopDiff_->getBuffer();

      int size = nImg*ofm*ofhp*ofwp;
      string s = nname_ + "_delOutp";
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
      int size = nImg*ofm*ofhp*ofwp;
      convert_bf16_f32(ptr, stptr, size);
      string s = nname_ + "_delOutp";
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }

    string s = nname_ + "_delgammap";
    float* delgamma = (float*)tenScaleDiff_->getBuffer();
    MeanOfLayer((char*)s.c_str(), delgamma, gparams_.nOutput);

    s = nname_ + "_delbetap";
    float* delbeta = (float*)tenShiftDiff_->getBuffer();
    MeanOfLayer((char*)s.c_str(), delbeta, gparams_.nOutput);

    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenBotDiff_[0]->getBuffer();
      string s = nname_ + "_delInp";
      int size = nImg*ifm0*ifhp*ifwp;
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenBotDiff_[0]->getBuffer();
      s = nname_ + "_delInp";
      int size = nImg*ifm0*ifhp*ifwp;
      convert_bf16_f32(ptr, stptr, size);
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }
  }
#endif
}

void FusedConvBNNode::weightUpdate()
{
  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifw = gparams_.iWidth;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  // printf("Executing WU %s: grad_output %p, grad_weights %p, input %p\n",NNNode::nname_.c_str(), gtop, gwt, bot);
  printf("Executing WU %s\n",NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, gparams_.oHeight, gparams_.oWidth);
  printf("Inputs: %d x %d x %d\n", ifm0, ifh, ifw);
  printf("del-Weights: %d x %d x %d x %d\n", ofm, ifm0, kh, kw);
  printf("del-Biases: %d\n", ofm);
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  int node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_delWt_Bef";
      float *ptr = (float*)tenWeightDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, ifm0*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_delWt_Bef";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm0*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm0*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm0*ofm*kh*kw);
    }
  }
#endif

  if(!bp_flag_ && first_upd)
  {
    int msize = nImg*ofm*mfhp*mfwp;

    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenMidDiff_->getBuffer();

      // NUMA initialize Conv delmidp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();

      // NUMA initialize Conv delmidp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;
    }
    first_upd = false;
  }

  tenTopDiff_ = tenTop_->getBuf(DIFF);
  impl->weightUpdate(tenBotData_[0], tenTopDiff_, tenMidDiff_, tenWeightDiff_, tenScaleDiff_, tenShiftDiff_, 0);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<16; i++)
    {
      float v = ((float*)tenWeightDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenWeightDiff_->getBuffer(), cbptr, 16);
    for(int i=0; i<16; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
#endif

  void* gexp[NUM_NUMA_NODES];
  void* gvar[NUM_NUMA_NODES];
  void* gexp_test = tenMeanData_->getPrivBuffer();
  void* gvar_test = tenVarData_->getPrivBuffer();

  void **mptrptr = tenMeanData_->getBufferPtr();
  void **vptrptr = tenVarData_->getBufferPtr();
  int offset = tenMeanData_->getOffset();

  for(int n=0; n<NUM_NUMA_NODES; n++)
    gexp[n] = mptrptr[n] + offset*sizeof(float);

  offset = tenVarData_->getOffset();
  for(int n=0; n<NUM_NUMA_NODES; n++)
    gvar[n] = vptrptr[n] + offset*sizeof(float);

#ifdef USE_MLSL
  void *mptr = tenWeightDiff_->getBuffer();

  if(in_dtype == DT_BF16)
  {
    if(dwptr == NULL)
    {
      int wsize = ifm0*ofm*kh*kw*sizeof(float);
      dwptr = (float*)MLSL::Environment::GetEnv().Alloc(wsize, 2097152);
    }
    convert_bf16_f32((libxsmm_bfloat16*)mptr, dwptr, ifm0*ofm*kh*kw);
    op_->GetParameterSet(0)->StartGradientComm(dwptr);
  }
  else if(in_dtype == DT_FLOAT)
    op_->GetParameterSet(0)->StartGradientComm(mptr);

  op_->GetParameterSet(1)->StartGradientComm(tenScaleDiff_->getBuffer());
  op_->GetParameterSet(2)->StartGradientComm(tenShiftDiff_->getBuffer());

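  // Average the per-NUMA-node partial batch means/variances, then divide by
  // the machine count so that the gradient communication (presumably a sum
  // across nodes) yields a globally averaged mean/variance.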
  int num_nodes = eptr_->get_num_machines();
  for(int i=0; i<ofm; i++)
  {
    float mtmp = 0.0;
    float vtmp = 0.0;

    for(int n=0; n<NUM_NUMA_NODES; n++)
    {
      mtmp += ((float*)gexp[n])[i];
      vtmp += ((float*)gvar[n])[i];
    }

    mtmp = mtmp/NUM_NUMA_NODES;
    vtmp = vtmp/NUM_NUMA_NODES;

    ((float*)gexp_test)[i] = mtmp/num_nodes;
    ((float*)gvar_test)[i] = vtmp/num_nodes;
  }
  this->op_->GetParameterSet(3)->StartGradientComm(gexp_test);
  this->op_->GetParameterSet(4)->StartGradientComm(gvar_test);
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_Inp";
      float *ptr = (float*)tenBotData_[0]->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm0*ifhp*ifwp);

      s = nname_ + "_delMidp";
      ptr = (float*)tenMidDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*mfhp*mfwp);

      s = nname_ + "_delWt_Aft";
      ptr = (float*)tenWeightDiff_->getBuffer();
      float *pptr = (float*)tenWeightDiff_->getPrivBuffer();
      float *p = (pptr == NULL) ? ptr : pptr;
      MeanOfLayer((char*)s.c_str(), p, ifm0*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_Inp";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenBotData_[0]->getBuffer();
      memset(stptr, 0, nImg*ifm0*ifhp*ifwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ifm0*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm0*ifhp*ifwp);

      s = nname_ + "_delMidp";
      ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();
      memset(stptr, 0, nImg*ofm*mfhp*mfwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ofm*mfhp*mfwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*mfhp*mfwp);

      s = nname_ + "_delWt_Aft";
#ifdef USE_MLSL
      MeanOfLayer((char*)s.c_str(), dwptr, ifm0*ofm*kh*kw);
#else
      ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm0*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm0*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm0*ofm*kh*kw);
#endif
    }
  }
#endif
}

void FusedConvBNNode::solverStep()
{
#ifdef USE_MLSL
  int ifm = gparams_.nInput[0];
  int ofm = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

  float *gwt = (float*)(tenWeightDiff_->getBuffer());
  float *delgamma = (float*)tenScaleDiff_->getBuffer();
  float *delbeta = (float*)tenShiftDiff_->getBuffer();
  void* gexp_test = tenMeanData_->getPrivBuffer();
  void* gvar_test = tenVarData_->getPrivBuffer();

  int wsize = ifm*ofm*kh*kw;

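  // WaitGradientComm() blocks until the reduction finishes; it may hand back
  // the reduced data in a library-owned buffer, in which case it is copied
  // into this node's gradient buffers.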
  void *mptr = op_->GetParameterSet(0)->WaitGradientComm();
  if(in_dtype == DT_FLOAT)
  {
    if(mptr != NULL && mptr != gwt)
      memcpy((void*)gwt, mptr, wsize*sizeof(float));
  }
  else if(in_dtype == DT_BF16)
  {
    if(mptr != NULL && mptr != dwptr)
      memcpy((void*)dwptr, mptr, wsize*sizeof(float));
    convert_f32_bf16(dwptr, (libxsmm_bfloat16*)gwt, wsize);
  }

  mptr = op_->GetParameterSet(1)->WaitGradientComm();
  if(mptr != NULL && mptr != delgamma)
    memcpy((void*)delgamma, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(2)->WaitGradientComm();
  if(mptr != NULL && mptr != delbeta)
    memcpy((void*)delbeta, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(3)->WaitGradientComm();
  if(mptr != NULL && mptr != gexp_test)
    memcpy((void*)gexp_test, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(4)->WaitGradientComm();
  if(mptr != NULL && mptr != gvar_test)
    memcpy((void*)gvar_test, mptr, ofm*sizeof(float));

#ifdef CHECK_BLOWUP_FP32
  // When the weight gradient is BF16, check the FP32 copy held in dwptr
  // rather than misreading the BF16 buffer as floats
  float* chkptr = (in_dtype == DT_BF16) ? dwptr : (float*)tenWeightDiff_->getBuffer();
  for(int i=0; i<16; i++)
  {
    if(isnan(chkptr[i]) || isinf(chkptr[i]))
    {
      printf("Warning! %s layer Solver gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
  for(int i=0; i<16; i++)
  {
    if(isnan(delgamma[i]) || isinf(delgamma[i]))
    {
      printf("Warning! %s layer Solver gamma gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
  for(int i=0; i<16; i++)
  {
    if(isnan(delbeta[i]) || isinf(delbeta[i]))
    {
      printf("Warning! %s layer Solver beta gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
#endif
#endif
}