/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar, Alexander Heinecke (Intel Corp.)
******************************************************************************/

#include <string>
#include "Conv.hpp"
#include "fillers.hpp"

#ifdef USE_MLSL
#include "mpi.h"
#endif


using namespace std;
using namespace gxm;

ConvNode::ConvNode(ConvParams* p, MLEngine* e) : NNNode(p, e)
{
  nname_ = p->get_node_name();
  ntype_ = p->get_node_type();
  bottom_ = p->get_bottom_names();
  top_ = p->get_top_names();
  bp_flag_ = p->get_bprop_flag();
  has_weights_ = true;
  compute_stats_ = p->get_compute_stats();
  bot_compute_engine_ = p->get_compute_engine();

  assert((bottom_.size() == 1) && (top_.size() == 1));
  bool bias_term = p->get_bias_term();

  tenTop_ = new Tensor(top_[0]);
  assert(tenTop_ != NULL);
  tenTop_->setOwner(this);
  tenTop_->setType(ACT);
  tenTopData_ = tenTop_->getBuf(DATA);
  tenTopData_->setBufferType(DATA);

#ifndef NDEBUG
  printf("bottom name %s\n", bottom_[0].c_str());
#endif

  if(bottom_[0] == "data")
    tenBot_ = e->get_tensor(bottom_[0], INPUT);
  else
    tenBot_ = e->get_tensor(bottom_[0], ACT);

  assert(tenBot_ != NULL);
  NNNode *pnn = (NNNode*)tenBot_->getOwner();
  setPrevNode(pnn);
  mode_ = pnn->getMode();
  pnn->set_top_compute_engine(p->get_compute_engine());
  bot_cengine_ = pnn->get_bot_compute_engine();

  tenBotData_ = tenBot_->getBuf(DATA);

  out_dtype = p->get_data_type();
  in_dtype = tenBotData_->getDataType();

  tenTopData_->setDataType(out_dtype);

  // Get input tensor shape (bottom)
  Shape* bs = tenBot_->getShape();
  assert(bs->ndims <= MAX_DIMS);

  // Create shape of output tensor (top)
  vector<int> vd = p->get_kernel_dims();
  vector<int> ovp = p->get_output_pads();
  vector<int> vp = p->get_pads();
  vector<int> vs = p->get_strides();

  assert((vd.size() == vp.size()) && (vd.size() == vs.size()) && (vs.size() == ovp.size()));

  shape_setzero(&ts_);

  ts_.ndims = bs->ndims; // Number of dimensions
  ts_.dims[0] = bs->dims[0]; // Minibatch size
  ts_.dims[1] = p->get_output(); // Num output feature maps
  ts_.dims[2] = (bs->dims[2] - vd[0] + 2*vp[0])/vs[0] + 1; // Height
  ts_.dims[3] = (bs->dims[3] - vd[1] + 2*vp[1])/vs[1] + 1; // Width
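  // Standard output-size formula: out = (in - k + 2*pad)/stride + 1; for
  // example (illustration only), a 224x224 input with a 7x7 kernel, stride 2
  // and pad 3 gives (224 - 7 + 6)/2 + 1 = 112, i.e. a 112x112 output map.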

  tenTop_->setShape(&ts_);

  long long int tsize = 0;
  int telem = ts_.dims[0] * ts_.dims[1] * (ts_.dims[2] + 2*ovp[0]) * (ts_.dims[3] + 2*ovp[1]);

  // Buffer space for sum and sum^2
  int tstats = 0;
  if(compute_stats_)
    tstats = 2*ts_.dims[0]*ts_.dims[1];

  if(out_dtype == DT_FLOAT)
    tsize = telem*sizeof(float) + tstats*sizeof(float);
  else if(out_dtype == DT_BF16)
    tsize = telem*sizeof(libxsmm_bfloat16) + tstats*sizeof(float);

  tenTopData_->setBufferSize(tsize);
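  // Layout note: the buffer sized above holds the physically padded
  // activations first, followed (when compute_stats_ is set) by 2*N*K floats
  // of per-image, per-output-feature-map sum and sum-of-squares scratch for
  // batch-norm statistics (zeroed in forwardPropagate).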

  // Create FP weight tensor
  weight_ = top_[0] + "_wt";
  tenWeight_ = new Tensor(weight_);
  assert(tenWeight_ != NULL);
  tenWeight_->setOwner(this);
  tenWeight_->setType(CONVWEIGHT);

  shape_setzero(&ws_);

  ws_.ndims = ts_.ndims;      // Number of dimensions
  ws_.dims[0] = ts_.dims[1];  // Num output feature maps (from top tensor)
  ws_.dims[1] = bs->dims[1];  // Num input feature maps (from bottom tensor)
  ws_.dims[2] = vd[0];        // Kernel height

  if(ts_.ndims == 4)
  {
    ws_.dims[3] = vd[1]; // Kernel width
  }
  else if(ts_.ndims == 5)
  {
    ws_.dims[3] = vd[1];
    ws_.dims[4] = vd[2];
  }

  tenWeight_->setShape(&ws_);
  tenWeight_->setBufDataType(DATA, DT_FLOAT);
  tenWeightData_ = tenWeight_->getBuf(DATA);
  tenWeightData_->setBufferType(DATA);

  int welem = 1;
  long long int wsize;
  for(int i=0; i<ws_.ndims; i++)
    welem = welem*ws_.dims[i];

  // size of master weights -- FP32
  wsize = welem*sizeof(float);

  gparams_.num_numa_nodes = NUM_NUMA_NODES;
  tenWeightData_->setBufferSize(wsize);

  wfiller_type_ = p->get_weight_filler_type();
  variance_norm_ = p->get_variance_norm();
  std_ = p->get_std();

  lr_mult_ = p->get_lr_mult();
  decay_mult_ = p->get_decay_mult();

  // Create bias tensor
  long long int bisize = 0;

  Shape bis;
  shape_setzero(&bis);

  if(bias_term)
  {
    bias_ = top_[0] + "_bias";
    tenBias_ = new Tensor(bias_);
    assert(tenBias_ != NULL);
    tenBias_->setOwner(this);
    tenBias_->setType(CONVBIAS);

    bis.ndims = 1;
    bis.dims[0] = ts_.dims[1];
    tenBias_->setShape(&bis);
    tenBiasData_ = tenBias_->getBuf(DATA);
    tenBiasData_->setDataType(DT_FLOAT);
    tenBiasData_->setBufferType(DATA);

    bisize = bis.dims[0];
    bisize = bisize*sizeof(float); // Biases are always in FP32
    tenBiasData_->setBufferSize(bisize);

    bfiller_type_ = p->get_bias_filler_type();
    value_ = p->get_value();
  }

  if(!e->is_inference_only()) {
    if(bp_flag_)
    {
      tenBotDiff_ = tenBot_->addBuf(); // DIFF type and index
      tenBotDiff_->setDataType(in_dtype);
      tenBotDiff_->setBufferType(DIFF);

      long long int bsize = bs->dims[0] * bs->dims[1] * (bs->dims[2] + 2*vp[0]) * (bs->dims[3] + 2*vp[1]);

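      // The input-gradient element type follows the output gradient: it is
      // FP32 whenever out_dtype is FP32 (even with a BF16 input tensor) and
      // BF16 only on the pure BF16 path.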
      if((in_dtype == DT_FLOAT && out_dtype == DT_FLOAT) ||
          (in_dtype == DT_BF16 && out_dtype == DT_FLOAT))
        bsize = bsize*sizeof(float);
      else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
        bsize = bsize*sizeof(libxsmm_bfloat16);

      // Set the size of the input-gradient buffer
      tenBotDiff_->setBufferSize(bsize);
    }

    if(has_weights_)
    {
      tenWeightDiff_ = tenWeight_->addBuf(); // DIFF type and index
      tenWeightDiff_->setBufferType(DIFF);

      tenWeightInc_ = tenWeight_->addBuf(); // SHARED type and index
      tenWeightInc_->setBufferType(HISTORY);
      tenWeightInc_->setDataType(DT_FLOAT);
      tenWeightInc_->setBufferSize(welem*sizeof(float));

      if(in_dtype == DT_FLOAT)
      {
        tenWeightDiff_->setDataType(DT_FLOAT);
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
      }
      else if(in_dtype == DT_BF16)
      {
        tenWeightDiff_->setDataType(DT_BF16);
#ifdef BF16_MLSL
        tenWeightDiff_->setBufferSize(welem*sizeof(libxsmm_bfloat16));
#else
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
#endif
      }
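      // Note: without an MLSL build that communicates BF16 directly
      // (BF16_MLSL), the BF16 weight-gradient tensor gets an FP32-sized
      // buffer so the same allocation can stage the up-converted FP32
      // gradients for communication (see weightUpdate/solverStep).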

      if(bias_term)
      {
        tenBiasDiff_ = tenBias_->addBuf(); // DIFF type and index
        tenBiasDiff_->setDataType(DT_FLOAT);
        tenBiasDiff_->setBufferType(DIFF);

        tenBiasInc_ = tenBias_->addBuf(); // SHARED type and index
        tenBiasInc_->setDataType(DT_FLOAT);
        tenBiasInc_->setBufferType(HISTORY);

        // Set the size of the bias-gradient buffer and the bias-increment buffer
        tenBiasDiff_->setBufferSize(bisize);
        tenBiasInc_->setBufferSize(bisize);
      }
    }
  }
  else {
    tenBotDiff_ = NULL;
    tenWeightDiff_ = NULL;
    tenWeightInc_ = NULL;
    tenBiasDiff_ = NULL;
    tenBiasInc_ = NULL;
  }

  // Register output tensor in tensor map
  bool inserted = e->register_tensor(top_[0], ACT, tenTop_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n", top_[0].c_str());

  // Register weight tensor in weight tensor map
  inserted = e->register_tensor(weight_, CONVWEIGHT, tenWeight_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n", weight_.c_str());

  // Register bias tensor in bias tensor map
  if(bias_term)
  {
    inserted = e->register_tensor(bias_, CONVBIAS, tenBias_);
    if(!inserted)
      printf("Warning: Tensor %s already registered\n", bias_.c_str());
  }

  // Setup parameter structure for convolution computation in library
  gparams_.bdims = bs->ndims;
  gparams_.tdims = ts_.ndims;
  gparams_.wdims = ws_.ndims;
  gparams_.bidims = bis.ndims;

  gparams_.node_name = nname_;
  gparams_.nInput = bs->dims[1];
  gparams_.nOutput = ts_.dims[1];
  gparams_.batch_size = bs->dims[0];
  gparams_.iHeight = bs->dims[2];
  gparams_.iWidth = bs->dims[3];
  gparams_.oHeight = ts_.dims[2];
  gparams_.oWidth = ts_.dims[3];
  gparams_.pad_h = vp[0];
  gparams_.pad_w = vp[1];
  gparams_.physical_padding = p->get_physical_padding();
  gparams_.compute_stats = compute_stats_;

  if(gparams_.physical_padding)
  {
    gparams_.ipad_h = vp[0];
    gparams_.ipad_w = vp[1];
    gparams_.opad_h = ovp[0];
    gparams_.opad_w = ovp[1];
  }
  else
  {
    gparams_.ipad_h = 0;
    gparams_.ipad_w = 0;
    gparams_.opad_h = 0;
    gparams_.opad_w = 0;
  }

  gparams_.group = p->get_group();
  gparams_.stride_h = vs[0];
  gparams_.stride_w = vs[1];
  gparams_.kh = ws_.dims[2];
  gparams_.kw = ws_.dims[3];

  gparams_.bias_term = bias_term;
  gparams_.relu = p->get_fused_relu();
  gparams_.bwd_relu = p->get_bwd_relu();

  gparams_.in_data_type = in_dtype;
  gparams_.out_data_type = out_dtype;
  gparams_.algType = p->get_algo_type();
  gparams_.num_threads = e->get_num_threads();

  // get solver
  solver_ = e->getSolver();

  // get global scratch tensor buffer
  tenScratchData_ = e->getScratchBuffer();

  // get engine
  eptr_ = e;

#ifdef USE_MLSL
  MLSL::DataType dt = MLSL::DT_FLOAT;
  MLSL::OperationRegInfo *myRegInfo;
  MLSL::Session *s = eptr_->get_session();
  myRegInfo = s->CreateOperationRegInfo(MLSL::OT_CC);
  myRegInfo->SetName(nname_.c_str());
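  // The weight gradient is registered as nInput*nOutput/group parameter sets
  // of kw*kh elements each, reduced across nodes in FP32 (dt above).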
  myRegInfo->AddParameterSet(gparams_.nInput*gparams_.nOutput/gparams_.group, gparams_.kw*gparams_.kh, dt, false);

  if(bias_term)
    myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);

  myRegInfo->Validate();
  size_t opIdx = s->AddOperation(myRegInfo, e->get_distribution());
  this->op_ = s->GetOperation(opIdx);
  s->DeleteOperationRegInfo(myRegInfo);
  e->get_wtgrad_comms_vec().push_back(op_);
#endif

  configure(p->get_compute_engine());
}

void ConvNode::fillWeightBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  void *ptr = tBuf->getBuffer();

#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif

  int ic = gparams_.nInput;
  int oc = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;
  int g = gparams_.group;
  int fanin = (ic * kh * kw)/g;
  int fanout = (oc * kh * kw)/g;
  int welem = ic * oc * kh * kw;
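  // fanin/fanout feed the variance-scaling weight fillers (see fillers.hpp);
  // e.g. an Xavier/Glorot-style rule would draw with std = sqrt(2.0/(fanin+fanout))
  // and an MSRA-style rule with std = sqrt(2.0/fanin); the choice is driven
  // by wfiller_type_ and variance_norm_.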

  if(buftype == DATA)
  {
    if(node_id == 0)
      initBuffer(ptr, variance_norm_, fanin, fanout, welem*sizeof(float), wfiller_type_, std_);

#ifdef USE_MLSL
    MPI_Bcast(ptr, welem, MPI_FLOAT, 0, MPI_COMM_WORLD);
#endif
  }
  else if(buftype == HISTORY || buftype == DIFF)
    memset(ptr, 0, size);
}

void ConvNode::fillWeightMultipliers(float* lr, float* decay, long long int size)
{
  for(long long int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[0];
    decay[i] = decay_mult_[0];
  }
}

void ConvNode::fillBiasBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  void *ptr = tBuf->getBuffer();

  if(buftype == DATA)
  {
    initConstantBuffer(ptr, size, "CONSTANT", value_);
  }
  else
    memset(ptr, 0, size);
}

void ConvNode::fillBiasMultipliers(float* lr, float* decay, long long int size)
{
  if(gparams_.bias_term)
  {
    for(long long int i=0; i < size; i++)
    {
      lr[i] = lr_mult_[1];
      decay[i] = decay_mult_[1];
    }
  }
}

void ConvNode::Checkpoint(TensorBuf *tBuf, string name, string format)
{
  long long int bytes = tBuf->getBufferSize();
  int dtype = tBuf->getDataType();
  int buftype = tBuf->getBufferType();

  FILE* f;
  void* ptr;
  size_t pos;

  if((name.find("30") == name.npos) && (name.find("60") == name.npos) && (name.find("80") == name.npos))
    while((pos = name.find("/", 10)) != name.npos)
      name.replace(pos, 1, 1, '_');

  float* p = (float*)tBuf->getBuffer();
  bool no_checkpt = false;
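  // Cheap sanity check: sample only the first 16 values and skip the
  // checkpoint entirely if any is NaN or Inf.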
  for(int i=0; i<16; i++)
  {
    if(isnan(p[i]) || isinf(p[i]))
    {
      no_checkpt = true;
      printf("Warning! %s Did not checkpoint! Weights are NaNs or Inf\n", nname_.c_str());
      break;
    }
  }

  if(!no_checkpt)
  {
    if(format == "binary")
    {
      f = fopen(name.c_str(), "wb");
      if(f != NULL)
      {
#if 0
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
#endif
          ptr = tBuf->getBuffer();

        size_t b = fwrite(ptr, 1, bytes, f);
        assert((long long int)b == bytes);

#if 0
        if(name.find("wt") != name.npos)
          _mm_free(ptr);
#endif
      }
      else
        printf("Warning: could not checkpoint to file %s\n", name.c_str());
    }
    else
    {
      f = fopen(name.c_str(), "w");
      if(f != NULL)
      {
#if 0
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
#endif
          ptr = tBuf->getBuffer();

        for(long long int i=0; i<bytes/(long long int)sizeof(float); i++)
          fprintf(f, "%f\n", *((float*)ptr + i));

#if 0
        if(name.find("wt") != name.npos)
          _mm_free(ptr);
#endif
      }
      else
        printf("Warning: could not checkpoint to file %s\n", name.c_str());
    }
    if(f != NULL)
    {
      fflush(f);
      fclose(f);
    }
  }
}

void ConvNode::configure(int engine)
{
  switch(engine)
  {
    case XSMM:
      impl = new ConvXSMM(&gparams_, engine);
      break;
  }
}

void ConvNode::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m512  vfp32  = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) );
    __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 );
    _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
  }
}
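
/* For reference, a scalar sketch (illustration only, not used anywhere) of the
   round-to-nearest-even FP32->BF16 conversion the AVX512 loop above performs,
   ignoring NaN special-casing:

     libxsmm_bfloat16 f32_to_bf16_rne(float x) {
       unsigned int u;
       memcpy(&u, &x, sizeof(u));            // reinterpret the FP32 bits
       unsigned int lsb = (u >> 16) & 1;     // parity of the last kept bit
       u += 0x7FFF + lsb;                    // bias so that ties round to even
       return (libxsmm_bfloat16)(u >> 16);   // keep sign, exponent, top mantissa
     }

   Note that the vectorized loop assumes len is a multiple of 16. */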

void ConvNode::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
#if 1

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = gparams_.num_threads/gparams_.num_numa_nodes;
    int n = tid/ntps;
    if(n == 0)
    {
      int lenv = len/16;
      int rem = lenv % ntps;
      int jobs = (rem == 0) ? (lenv/ntps)*16 : ((lenv-rem)/ntps)*16;
      int tb = (tid*jobs < len) ? tid*jobs : len;
      int te = ((tid+1)*jobs < len) ? (tid+1)*jobs : len;

      for (int i = tb; i < te; i+=16 ) {
        __m256i vbfp16    = _mm256_loadu_si256( (const __m256i*)(in+i) );
        __m512  vfp32     = gxm_bfp16_to_fp32_avx512f( vbfp16 );
        _mm512_storeu_ps( out+i, vfp32 );
      }

      // Remainder processing
      if(tid == 0)
      {
        if(rem > 0)
        {
          for(int i=ntps*jobs; i<len; i+=16)
          {
            __m256i vbfp16    = _mm256_loadu_si256( (const __m256i*)(in+i) );
            __m512  vfp32     = gxm_bfp16_to_fp32_avx512f( vbfp16 );
            _mm512_storeu_ps( out+i, vfp32 );
          }
        }
      }
    }
  }
#else

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = gparams_.num_threads/gparams_.num_numa_nodes;
    int n = tid/ntps;

    if(n == 0)
    {
      union libxsmm_bfloat16_hp delwt_32_0;
      delwt_32_0.i[0] = 0;

      int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
      int tb = (tid*jobs < len) ? tid*jobs : len;
      int te = ((tid+1)*jobs < len) ? (tid+1)*jobs : len;

      for(int j=tb; j<te; j++)
      {
        delwt_32_0.i[1] = in[j];
        out[j] = delwt_32_0.f;
      }
    }
  }
#endif
}

void ConvNode::forwardPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifndef NDEBUG
  // printf("Executing FP %s: input %p, weights %p, output %p\n",NNNode::nname_.c_str(), bot, wt, top);
  printf("Executing FP %s\n", NNNode::nname_.c_str());
  printf("Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Weights: %d x %d x %d x %d\n", ifm, ofm, kh, kw);
  printf("Bias: %d\n", ofm);

  if (gparams_.relu) printf("Fused relu\n");
#endif

  impl->set_top_compute_engine(top_compute_engine_);
  impl->set_bot_compute_engine(bot_cengine_);
  impl->set_node_name(nname_);
  impl->set_scratch_buffer(tenScratchData_);

  long long int size = nImg * ofm * ofhp * ofwp;

  if(first_fp)
  {
    if(tenTopData_->getDataType() == DT_FLOAT)
    {
      float* ptr = (float*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(tenTopData_->getDataType() == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

#ifdef CHECK_BLOWUP_FP32
    cbptr = (float*)_mm_malloc(10240*sizeof(float), 64);
#endif

    first_fp = false;
  }

  if(tenTopData_->getDataType() == DT_FLOAT)
  {
    float* ptr = (float*)tenTopData_->getBuffer();
    if(compute_stats_)
    {
      float* sptr = ptr + size;

      /* @TODO move this into Batch Norm/LIBXSMM */
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<2*nImg*ofm; i++)
        sptr[i] = 0;
    }
  }
  else if(tenTopData_->getDataType() == DT_BF16)
  {
    libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
    if(compute_stats_)
    {
      libxsmm_bfloat16* sptr = ptr + size;

      /* @TODO move this into Batch Norm/LIBXSMM */
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<2*nImg*ofm; i++)
        sptr[i] = 0;
    }
  }

  impl->forwardPropagate(tenBotData_, tenWeightData_, tenWeightInc_, tenBiasData_, tenTopData_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<16; i++)
    {
      float v = ((float*)tenTopData_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenTopData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id == 0)
#endif
  {
    if(out_dtype == DT_FLOAT)
    {
      float *ptr, *pptr, *p;

      if(eptr_->get_current_batch() % STATFREQ == 0)
      {
        string s = nname_ + "_Inp";
        ptr = (float*)tenBotData_->getBuffer();
        pptr = (float*)tenBotData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, nImg*ifm*ifhp*ifwp);

        s = nname_ + "_Wt";
        ptr = (float*)tenWeightData_->getBuffer();
        pptr = (float*)tenWeightData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, ifm*ofm*kh*kw);

        if(gparams_.bias_term)
        {
          s = nname_ + "_Bias";
          p = (float*)tenBiasData_->getBuffer();
          MeanOfLayer((char*)s.c_str(), p, ofm);
        }

        s = nname_ + "_Outp";
        ptr = (float*)tenTopData_->getBuffer();
        pptr = (float*)tenTopData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, nImg*ofm*ofhp*ofwp);

        if(compute_stats_)
        {
          // The stats region is allocated and zeroed as FP32 (see the
          // constructor and the zeroing above), so it is read as float here.
          s = nname_ + "_sump";
          float* m = p + nImg*ofm*ofhp*ofwp;
          MeanOfLayer((char*)s.c_str(), m, nImg*ofm);

          s = nname_ + "_sum2p";
          float* m2 = m + nImg*ofm;
          MeanOfLayer((char*)s.c_str(), m2, nImg*ofm);
        }
      }
    }
    else if(out_dtype == DT_BF16)
    {
      if(stptr == NULL)
      {
        int os = nImg*ofm*ofhp*ofwp;
        int is = nImg*ifm*ifhp*ifwp;
        int ws = ifm*ofm*kh*kw;
        int m = os < is ? is : os;
        int msize = m < ws ? ws : m;
        stptr = (float*)libxsmm_aligned_malloc(msize*sizeof(float), 2097152);
      }

      {
        string s = nname_ + "_Inp";
        libxsmm_bfloat16 *ptr;
        if(tenBotData_->getLPBuffer() != NULL)
          ptr = (libxsmm_bfloat16*)tenBotData_->getLPBuffer();
        else
          ptr = (libxsmm_bfloat16*)tenBotData_->getBuffer();
        convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);

        s = nname_ + "_Wt";
        float *fptr = (float*)tenWeightData_->getBuffer();
        int w = ifm*ofm*kh*kw;
        MeanOfLayer((char*)s.c_str(), fptr, w);

        if(gparams_.bias_term)
        {
          s = nname_ + "_Bias";
          float *p = (float*)tenBiasData_->getBuffer();
          MeanOfLayer((char*)s.c_str(), p, ofm);
        }

        s = nname_ + "_Outp";
        ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
        memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

        if(compute_stats_)
        {
          s = nname_ + "_sump";
          int offset = nImg*ofm*ofhp*ofwp*sizeof(float);
          void* m = (void*)ptr + offset;
          MeanOfLayer((char*)s.c_str(), (float*)m, nImg*ofm);

          s = nname_ + "_sum2p";
          void* m2 = (void*)m + nImg*ofm*sizeof(float);
          MeanOfLayer((char*)s.c_str(), (float*)m2, nImg*ofm);
        }
      }
    }
  }
#endif
}

void ConvNode::backPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  printf("Executing BP %s\n", NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Grad Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("Weights: %d x %d x %d x %d\n", ofm, ifm, kh, kw);
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  if(first_bp)
  {
    long long int size = nImg * ifm * ifhp * ifwp;

    if((in_dtype == DT_BF16 && out_dtype == DT_FLOAT)
        || (in_dtype == DT_FLOAT && out_dtype == DT_FLOAT))
    {
      float* ptr = (float*)tenBotDiff_->getBuffer();
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenBotDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

    first_bp = false;
  }

  impl->backPropagate(tenTopData_, tenWeightData_, tenTopDiff_, tenBotDiff_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenBotDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenBotDiff_->getBuffer(), cbptr, 10240);
#ifdef USE_MLSL
    int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
    int node_id = 0;
#endif
    if(node_id == 0)
    {
      for(int i=0; i<10240; i++)
      {
        if(isnan(cbptr[i]) || isinf(cbptr[i]))
        {
          printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
          MeanOfLayer((char*)((nname_+"_delin").c_str()), (libxsmm_bfloat16*)tenBotDiff_->getBuffer(), nImg*ifm*ifhp*ifwp);
          MeanOfLayer((char*)((nname_+"_delout").c_str()), (libxsmm_bfloat16*)tenTopDiff_->getBuffer(), nImg*ofm*ofhp*ofwp);
          MeanOfLayer((char*)((nname_+"_weight").c_str()), (libxsmm_bfloat16*)tenWeightData_->getLPBuffer(), ofm*ifm*kh*kw);
#ifdef USE_MLSL
          MPI_Finalize();
#endif
          exit(-1);
        }
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id_ == 0)
#endif
  {
    if(eptr_->get_current_batch() % STATFREQ == 0)
    {
      if(in_dtype == DT_FLOAT && out_dtype == DT_FLOAT)
      {
        string s = nname_ + "_delOutp";

        float *ptr = (float*)tenTopDiff_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*ofhp*ofwp);

        s = nname_ + "_Wt";
        ptr = (float*)tenWeightData_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, ifm*ofm*kh*kw);

        s = nname_ + "_delInp";
        ptr = (float*)tenBotDiff_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm*ifhp*ifwp);
      }
      else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
      {
        string s = nname_ + "_delOutp";

        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
        memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

        s = nname_ + "_Wt";
        float *fptr = (float*)tenWeightData_->getBuffer();
        MeanOfLayer((char*)s.c_str(), fptr, ifm*ofm*kh*kw);

        s = nname_ + "_delInp";
        ptr = (libxsmm_bfloat16*)tenBotDiff_->getBuffer();
        memset(stptr, 0, nImg*ifm*ifhp*ifwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);
      }
    }
  }
#endif
}

void ConvNode::weightUpdate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifw = gparams_.iWidth;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  // printf("Executing WU %s: grad_output %p, grad_weights %p, input %p\n",NNNode::nname_.c_str(), gtop, gwt, bot);
  printf("Executing WU %s\n", NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("del-Weights: %d x %d x %d x %d\n", ofm, ifm, kh, kw);
  printf("del-Biases: %d\n", ofm);
#endif

#ifdef GETSTATS
  int node_id = 0;
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id == 0 && eptr_->get_current_batch() % STATFREQ == 0)
#else
  if(eptr_->get_current_batch() % STATFREQ == 0)
#endif
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_delWt_Bef";
      float *ptr = (float*)tenWeightDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, ifm*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_delWt_Bef";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm*ofm*kh*kw);
    }

    if(gparams_.bias_term)
    {
      string s = nname_ + "_delBias_Bef";
      float *p = (float*)tenBiasDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), p, ofm);
    }
  }
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  impl->weightUpdate(tenBotData_, tenTopDiff_, tenWeightDiff_, tenBiasDiff_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenWeightDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer weight-gradients are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
#ifdef BF16_MLSL
    void **wptrptr = tenWeightDiff_->getBufferPtr();
#else
    void **wptrptr = tenWeightDiff_->getLPBufferPtr();
#endif
    int offset = tenWeightDiff_->getOffset();
    void* bf16_wtdiff = wptrptr[0] + offset*sizeof(libxsmm_bfloat16);

    convert_bf16_f32((libxsmm_bfloat16*)bf16_wtdiff, cbptr, 10240);
#ifdef USE_MLSL
    int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
    int node_id = 0;
#endif
    if(node_id == 0)
    {
      for(int i=0; i<10240; i++)
      {
        if(isnan(cbptr[i]) || isinf(cbptr[i]))
        {
          printf("Warning! %s layer weight-gradients are NaN or Inf\n", nname_.c_str());
          MeanOfLayer((char*)nname_.c_str(), (libxsmm_bfloat16*)bf16_wtdiff, ofm*ifm*kh*kw);
          exit(-1);
        }
      }
    }
  }
#endif

#ifdef USE_MLSL
  void *mptr = tenWeightDiff_->getBuffer();

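  // Without BF16_MLSL, gradients are reduced in FP32: the BF16 low-precision
  // gradient buffer is first up-converted into the FP32-sized main buffer,
  // which is then handed to StartGradientComm.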
#ifndef BF16_MLSL
  void *lmptr = tenWeightDiff_->getLPBuffer();

  if(in_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)lmptr, (float*)mptr, ifm*ofm*kh*kw);
    op_->GetParameterSet(0)->StartGradientComm(mptr);
  }
  else if(in_dtype == DT_FLOAT)
    op_->GetParameterSet(0)->StartGradientComm(mptr);
#else
  op_->GetParameterSet(0)->StartGradientComm(mptr);
#endif

  if(gparams_.bias_term)
    op_->GetParameterSet(1)->StartGradientComm(tenBiasDiff_->getBuffer());
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_Inp";
      float *ptr = (float*)tenBotData_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm*ifhp*ifwp);

      s = nname_ + "_delOutp";
      ptr = (float*)tenTopDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*ofhp*ofwp);

      s = nname_ + "_delWt_Aft";
      ptr = (float*)tenWeightDiff_->getBuffer();
      float *pptr = (float*)tenWeightDiff_->getPrivBuffer();
      float *p = (pptr == NULL) ? ptr : pptr;
      MeanOfLayer((char*)s.c_str(), p, ifm*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_Inp";
      libxsmm_bfloat16 *ptr;
      if(tenBotData_->getLPBuffer() != NULL)
        ptr = (libxsmm_bfloat16*)tenBotData_->getLPBuffer();
      else
        ptr = (libxsmm_bfloat16*)tenBotData_->getBuffer();

      memset(stptr, 0, nImg*ifm*ifhp*ifwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);

      s = nname_ + "_delOutp";
      ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
      memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

      s = nname_ + "_delWt_Aft";
      ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm*ofm*kh*kw);
    }

    if(gparams_.bias_term)
    {
      string s = nname_ + "_delBias_Aft";
      float *p = (float*)tenBiasDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), p, ofm);
    }
  }
#endif
}

void ConvNode::solverStep()
{
#ifdef USE_MLSL
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

  void *gwt = tenWeightDiff_->getBuffer();

  float *gbias = NULL;
  if(gparams_.bias_term)
    gbias = (float*)(tenBiasDiff_->getBuffer());

  int wsize = ifm*ofm*kh*kw;
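
  // WaitGradientComm blocks until the gradient allreduce completes; it may
  // return a pointer to MLSL's internal buffer, in which case the reduced
  // gradients are copied back into this node's own gradient storage below.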
  void *mptr = op_->GetParameterSet(0)->WaitGradientComm();
  if(in_dtype == DT_FLOAT)
  {
    if(mptr != NULL && mptr != gwt)
      memcpy((void*)gwt, mptr, wsize*sizeof(float));
  }
  else if(in_dtype == DT_BF16)
  {
    if(mptr != NULL && mptr != dwptr)
      memcpy((void*)dwptr, mptr, wsize*sizeof(float));
    convert_f32_bf16(dwptr, (libxsmm_bfloat16*)gwt, wsize);
  }

  if(gparams_.bias_term)
  {
    mptr = op_->GetParameterSet(1)->WaitGradientComm();
    if(mptr != NULL && mptr != gbias)
      memcpy((void*)gbias, mptr, ofm*sizeof(float));
  }
#endif
}