/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar, Alexander Heinecke (Intel Corp.)
******************************************************************************/

#include <string>
#include "Conv.hpp"
#include "fillers.hpp"

#ifdef USE_MLSL
#include "mpi.h"
#endif


using namespace std;
using namespace gxm;

ConvNode::ConvNode(ConvParams* p, MLEngine* e): NNNode(p, e)
{
  nname_ = p->get_node_name();
  ntype_ = p->get_node_type();
  bottom_ = p->get_bottom_names();
  top_ = p->get_top_names();
  bp_flag_ = p->get_bprop_flag();
  has_weights_ = true;
  compute_stats_ = p->get_compute_stats();
  bot_compute_engine_ = p->get_compute_engine();

  assert((bottom_.size() == 1) && (top_.size() == 1));
  bool bias_term = p->get_bias_term();

  tenTop_ = new Tensor(top_[0]);
  assert(tenTop_ != NULL);
  tenTop_->setOwner(this);
  tenTop_->setType(ACT);
  tenTopData_ = tenTop_->getBuf(DATA);
  tenTopData_->setBufferType(DATA);

#ifndef NDEBUG
  printf("bottom name %s\n", bottom_[0].c_str());
#endif

  if(bottom_[0] == "data")
    tenBot_ = e->get_tensor(bottom_[0], INPUT);
  else
    tenBot_ = e->get_tensor(bottom_[0], ACT);

  assert(tenBot_ != NULL);
  NNNode *pnn = (NNNode*)tenBot_->getOwner();
  setPrevNode(pnn);
  mode_ = pnn->getMode();
  pnn->set_top_compute_engine(p->get_compute_engine());
  bot_cengine_ = pnn->get_bot_compute_engine();

  tenBotData_ = tenBot_->getBuf(DATA);

  out_dtype = p->get_data_type();
  in_dtype = tenBotData_->getDataType();

  tenTopData_->setDataType(out_dtype);

  // Get input tensor shape (bottom)
  Shape* bs = tenBot_->getShape();
  assert(bs->ndims <= MAX_DIMS);

  // Create shape of output tensor (top)
  vector<int> vd = p->get_kernel_dims();
  vector<int> ovp = p->get_output_pads();
  vector<int> vp = p->get_pads();
  vector<int> vs = p->get_strides();

  assert((vd.size() == vp.size()) && (vd.size() == vs.size()) && (vs.size() == ovp.size()));

  shape_setzero(&ts_);

  ts_.ndims = bs->ndims;         // Number of dimensions
  ts_.dims[0] = bs->dims[0];     // Minibatch size
  ts_.dims[1] = p->get_output(); // Num output feature maps
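  // Standard convolution output size: out = (in - k + 2*pad)/stride + 1.
  // E.g. (illustrative values only): a 224x224 input with k=7, pad=3,
  // stride=2 gives (224 - 7 + 6)/2 + 1 = 112.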
  ts_.dims[2] = (bs->dims[2] - vd[0] + 2*vp[0])/vs[0] + 1; // Height
  ts_.dims[3] = (bs->dims[3] - vd[1] + 2*vp[1])/vs[1] + 1; // Width

  tenTop_->setShape(&ts_);

  long long int tsize;
  int telem = ts_.dims[0] * ts_.dims[1] * (ts_.dims[2] + 2*ovp[0]) * (ts_.dims[3] + 2*ovp[1]);

  // Buffer space for sum and sum^2
  int tstats = 0;
  if(compute_stats_)
    tstats = 2*ts_.dims[0]*ts_.dims[1];
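  // (Two values per (image, output-feature) pair -- sum and sum of squares --
  // appended after the activations; they are zeroed in forwardPropagate().)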

  if(out_dtype == DT_FLOAT)
    tsize = telem*sizeof(float) + tstats*sizeof(float);
  else if(out_dtype == DT_BF16)
    tsize = telem*sizeof(libxsmm_bfloat16) + tstats*sizeof(float);

  tenTopData_->setBufferSize(tsize);

  // Create FP weight tensor
  weight_ = top_[0] + "_wt";
  tenWeight_ = new Tensor(weight_);
  assert(tenWeight_ != NULL);
  tenWeight_->setOwner(this);
  tenWeight_->setType(CONVWEIGHT);

  shape_setzero(&ws_);

  ws_.ndims = ts_.ndims;     // Number of dimensions
  ws_.dims[0] = ts_.dims[1]; // Num output feature maps (from top tensor)
  ws_.dims[1] = bs->dims[1]; // Num input feature maps (from bottom tensor)
  ws_.dims[2] = vd[0];       // Kernel height

  if(ts_.ndims == 4)
  {
    ws_.dims[3] = vd[1];     // Kernel width
  }
  else if(ts_.ndims == 5)
  {
    ws_.dims[3] = vd[1];
    ws_.dims[4] = vd[2];
  }

  tenWeight_->setShape(&ws_);
  tenWeight_->setBufDataType(DATA, DT_FLOAT);
  tenWeightData_ = tenWeight_->getBuf(DATA);
  tenWeightData_->setBufferType(DATA);

  int welem = 1;
  long long int wsize;
  for(int i=0; i<ws_.ndims; i++)
    welem = welem*ws_.dims[i];

  // size of master weights -- FP32
  wsize = welem*sizeof(float);
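  // Note: even when compute runs in BF16, the solver keeps this FP32 master
  // copy of the weights (standard mixed-precision practice) so rounding error
  // does not accumulate across updates.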

  gparams_.num_numa_nodes = NUM_NUMA_NODES;
  tenWeightData_->setBufferSize(wsize);

  wfiller_type_ = p->get_weight_filler_type();
  variance_norm_ = p->get_variance_norm();
  std_ = p->get_std();

  lr_mult_ = p->get_lr_mult();
  decay_mult_ = p->get_decay_mult();

  // Create bias tensor
  long long int bisize;

  Shape bis;
  shape_setzero(&bis); // ensure bis.ndims is defined even when bias_term is false
  {
    if(bias_term)
    {
      bias_ = top_[0] + "_bias";
      tenBias_ = new Tensor(bias_);
      assert(tenBias_ != NULL);
      tenBias_->setOwner(this);
      tenBias_->setType(CONVBIAS);

      bis.ndims = 1;
      bis.dims[0] = ts_.dims[1];
      tenBias_->setShape(&bis);
      tenBiasData_ = tenBias_->getBuf(DATA);
      tenBiasData_->setDataType(DT_FLOAT);
      tenBiasData_->setBufferType(DATA);

      bisize = bis.dims[0];
      bisize = bisize*sizeof(float); // Biases are always in FP32
      tenBiasData_->setBufferSize(bisize);

      bfiller_type_ = p->get_bias_filler_type();
      value_ = p->get_value();
    }
  }

  if(!e->is_inference_only()) {
    if(bp_flag_)
    {
      tenBotDiff_ = tenBot_->addBuf(); // DIFF type and index
      tenBotDiff_->setDataType(in_dtype);
      tenBotDiff_->setBufferType(DIFF);

      long long int bsize = bs->dims[0] * bs->dims[1] * (bs->dims[2] + 2*vp[0]) * (bs->dims[3] + 2*vp[1]);

      if((in_dtype == DT_FLOAT && out_dtype == DT_FLOAT) ||
         (in_dtype == DT_BF16 && out_dtype == DT_FLOAT))
        bsize = bsize*sizeof(float);
      else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
        bsize = bsize*sizeof(libxsmm_bfloat16);

      // Set the size of the input-gradient buffer
      tenBotDiff_->setBufferSize(bsize);
    }

    if(has_weights_)
    {
      tenWeightDiff_ = tenWeight_->addBuf(); // DIFF type and index
      tenWeightDiff_->setBufferType(DIFF);

      tenWeightInc_ = tenWeight_->addBuf(); // SHARED type and index
      tenWeightInc_->setBufferType(HISTORY);
      tenWeightInc_->setDataType(DT_FLOAT);
      tenWeightInc_->setBufferSize(welem*sizeof(float));

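      // With BF16_MLSL defined, the weight gradient is both stored and
      // communicated as BF16; otherwise it stays FP32-sized so it can be
      // upconverted before the gradient allreduce (see weightUpdate()).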
      if(in_dtype == DT_FLOAT)
      {
        tenWeightDiff_->setDataType(DT_FLOAT);
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
      }
      else if(in_dtype == DT_BF16)
      {
        tenWeightDiff_->setDataType(DT_BF16);
#ifdef BF16_MLSL
        tenWeightDiff_->setBufferSize(welem*sizeof(libxsmm_bfloat16));
#else
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
#endif
      }

      if(bias_term)
      {
        tenBiasDiff_ = tenBias_->addBuf(); // DIFF type and index
        tenBiasDiff_->setDataType(DT_FLOAT);
        tenBiasDiff_->setBufferType(DIFF);

        tenBiasInc_ = tenBias_->addBuf(); // SHARED type and index
        tenBiasInc_->setDataType(DT_FLOAT);
        tenBiasInc_->setBufferType(HISTORY);

        // Set the sizes of the bias-gradient and bias-increment buffers
        tenBiasDiff_->setBufferSize(bisize);
        tenBiasInc_->setBufferSize(bisize);
      }
    }
  }
  else {
    tenBotDiff_ = NULL;
    tenWeightDiff_ = NULL;
    tenWeightInc_ = NULL;
    tenBiasDiff_ = NULL;
    tenBiasInc_ = NULL;
  }

  // Register output tensor in tensor map
  bool inserted = e->register_tensor(top_[0], ACT, tenTop_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n", top_[0].c_str());

  // Register weight tensor in weight tensor map
  inserted = e->register_tensor(weight_, CONVWEIGHT, tenWeight_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n", weight_.c_str());

  // Register bias tensor in bias tensor map
  if(bias_term)
  {
    inserted = e->register_tensor(bias_, CONVBIAS, tenBias_);
    if(!inserted)
      printf("Warning: Tensor %s already registered\n", bias_.c_str());
  }


  // Setup parameter structure for convolution computation in library
  gparams_.bdims = bs->ndims;
  gparams_.tdims = ts_.ndims;
  gparams_.wdims = ws_.ndims;
  gparams_.bidims = bis.ndims;

  gparams_.node_name = nname_;
  gparams_.nInput = bs->dims[1];
  gparams_.nOutput = ts_.dims[1];
  gparams_.batch_size = bs->dims[0];
  gparams_.iHeight = bs->dims[2];
  gparams_.iWidth = bs->dims[3];
  gparams_.oHeight = ts_.dims[2];
  gparams_.oWidth = ts_.dims[3];
  gparams_.pad_h = vp[0];
  gparams_.pad_w = vp[1];
  gparams_.physical_padding = p->get_physical_padding();
  gparams_.compute_stats = compute_stats_;

  if(gparams_.physical_padding)
  {
    gparams_.ipad_h = vp[0];
    gparams_.ipad_w = vp[1];
    gparams_.opad_h = ovp[0];
    gparams_.opad_w = ovp[1];
  }
  else
  {
    gparams_.ipad_h = 0;
    gparams_.ipad_w = 0;
    gparams_.opad_h = 0;
    gparams_.opad_w = 0;
  }

  gparams_.group = p->get_group();
  gparams_.stride_h = vs[0];
  gparams_.stride_w = vs[1];
  gparams_.kh = ws_.dims[2];
  gparams_.kw = ws_.dims[3];

  gparams_.bias_term = bias_term;
  gparams_.relu = p->get_fused_relu();
  gparams_.bwd_relu = p->get_bwd_relu();

  gparams_.in_data_type = in_dtype;
  gparams_.out_data_type = out_dtype;
  gparams_.algType = p->get_algo_type();
  gparams_.num_threads = e->get_num_threads();

  // get solver
  solver_ = e->getSolver();

  // get global scratch tensor buffer
  tenScratchData_ = e->getScratchBuffer();

  // get engine
  eptr_ = e;

#ifdef USE_MLSL
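  // Register this layer's learnable parameters with MLSL so their gradients
  // are allreduced across nodes during distributed training: the weights as
  // (nInput*nOutput/group) kernels of kw*kh elements each, plus the biases.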
  MLSL::DataType dt = MLSL::DT_FLOAT;
  MLSL::OperationRegInfo *myRegInfo;
  MLSL::Session *s = eptr_->get_session();
  myRegInfo = s->CreateOperationRegInfo(MLSL::OT_CC);
  myRegInfo->SetName(nname_.c_str());
  myRegInfo->AddParameterSet(gparams_.nInput*gparams_.nOutput/gparams_.group, gparams_.kw*gparams_.kh, dt, false);

  if(bias_term)
    myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);

  myRegInfo->Validate();
  size_t opIdx = s->AddOperation(myRegInfo, e->get_distribution());
  this->op_ = s->GetOperation(opIdx);
  s->DeleteOperationRegInfo(myRegInfo);
  e->get_wtgrad_comms_vec().push_back(op_);
#endif

  configure(p->get_compute_engine());
}

void ConvNode::fillWeightBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  void *ptr = tBuf->getBuffer();

#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif

  int ic = gparams_.nInput;
  int oc = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;
  int g = gparams_.group;
  int fanin = (ic * kh * kw)/g;
  int fanout = (oc * kh * kw)/g;
  int welem = ic * oc * kh * kw;

  if(buftype == DATA)
  {
    if(node_id == 0)
      initBuffer(ptr, variance_norm_, fanin, fanout, welem*sizeof(float), wfiller_type_, std_);

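    // Initialize on rank 0 only, then broadcast so every node starts from
    // identical weights.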
#ifdef USE_MLSL
    MPI_Bcast(ptr, welem, MPI_FLOAT, 0, MPI_COMM_WORLD);
#endif
  }
  else if(buftype == HISTORY || buftype == DIFF)
    memset(ptr, 0, size);
}

void ConvNode::fillWeightMultipliers(float* lr, float* decay, long long int size)
{
  for(long long int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[0];
    decay[i] = decay_mult_[0];
  }
}

void ConvNode::fillBiasBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  void *ptr = tBuf->getBuffer();

  if(buftype == DATA)
  {
    initConstantBuffer(ptr, size, "CONSTANT", value_);
  }
  else
    memset(ptr, 0, size);
}

void ConvNode::fillBiasMultipliers(float* lr, float* decay, long long int size)
{
  if(gparams_.bias_term)
  {
    for(long long int i=0; i < size; i++)
    {
      lr[i] = lr_mult_[1];
      decay[i] = decay_mult_[1];
    }
  }
}

void ConvNode::Checkpoint(TensorBuf *tBuf, string name, string format)
{
  long long int bytes = tBuf->getBufferSize();
  int dtype = tBuf->getDataType();
  int buftype = tBuf->getBufferType();

  FILE* f;
  void* ptr;
  size_t pos;

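  // Flatten the checkpoint path: unless the name carries one of the
  // "30"/"60"/"80" markers (presumably epoch tags), replace every '/' found
  // after position 10 with '_'.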
  if((name.find("30") == name.npos) && (name.find("60") == name.npos) && (name.find("80") == name.npos))
    while((pos = name.find("/", 10)) != name.npos)
      name.replace(pos, 1, 1, '_');

  float* p = (float*)tBuf->getBuffer();
  bool no_checkpt = false;
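  // Spot-check the first 16 values; skip the checkpoint entirely if any are
  // NaN or Inf.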
  for(int i=0; i<16; i++)
  {
    if(isnan(p[i]) || isinf(p[i]))
    {
      no_checkpt = true;
      printf("Warning! %s Did not checkpoint! Weights are NaNs or Inf\n", nname_.c_str());
      break;
    }
  }

  if(!no_checkpt)
  {
    if(format == "binary")
    {
      f = fopen(name.c_str(), "wb");
      if(f != NULL)
      {
#if 0
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
#endif
          ptr = tBuf->getBuffer();

        size_t b = fwrite(ptr, 1, bytes, f);
        assert((long long int)b == bytes);

#if 0
        if(name.find("wt") != name.npos)
          _mm_free(ptr);
#endif
      }
      else
        printf("Warning: could not checkpoint to file %s\n", name.c_str());
    }
    else
    {
      f = fopen(name.c_str(), "w");
      if(f != NULL)
      {
#if 0
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
#endif
          ptr = tBuf->getBuffer();

        for(long long int i=0; i<bytes/(long long int)sizeof(float); i++)
          fprintf(f, "%f\n", *((float*)ptr + i));

#if 0
        if(name.find("wt") != name.npos)
          _mm_free(ptr);
#endif
      }
      else
        printf("Warning: could not checkpoint to file %s\n", name.c_str());
    }
    if(f != NULL)
    {
      fflush(f);
      fclose(f);
    }
  }
}

void ConvNode::configure(int engine)
{
  switch(engine)
  {
    case XSMM:
      impl = new ConvXSMM(&gparams_, engine);
      break;
  }
}

void ConvNode::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len)
{
  int i;

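  // Convert 16 floats per iteration: apply the round-to-nearest-even
  // adjustment in FP32, then truncate to the upper 16 bits (BF16). Assumes
  // len is a multiple of 16 and both buffers are sized accordingly.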
#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) );
    __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 );
    _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
  }
}

void ConvNode::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
#if 1

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = gparams_.num_threads/gparams_.num_numa_nodes;
    int n = tid/ntps;
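    // Threads are grouped by NUMA node (ntps threads per node); only node 0's
    // threads do the conversion. The work is split into 16-element chunks,
    // with thread 0 mopping up any remainder.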
    if(n == 0)
    {
      int lenv = len/16;
      int rem = lenv % ntps;
      int jobs = (rem == 0) ? (lenv/ntps)*16 : ((lenv-rem)/ntps)*16;
      int tb = (tid*jobs < len) ? tid*jobs : len;
      int te = ((tid+1)*jobs < len) ? (tid+1)*jobs : len;

      for (int i = tb; i < te; i+=16 ) {
        __m256i vbfp16 = _mm256_loadu_si256( (const __m256i*)(in+i) );
        __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 );
        _mm512_storeu_ps( out+i, vfp32 );
      }

      // Remainder processing
      if(tid == 0)
      {
        if(rem > 0)
        {
          for(int i=ntps*jobs; i<len; i+=16)
          {
            __m256i vbfp16 = _mm256_loadu_si256( (const __m256i*)(in+i) );
            __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 );
            _mm512_storeu_ps( out+i, vfp32 );
          }
        }
      }
    }
  }
#else

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = gparams_.num_threads/gparams_.num_numa_nodes;
    int n = tid/ntps;

    if(n == 0)
    {
      union libxsmm_bfloat16_hp delwt_32_0;
      delwt_32_0.i[0] = 0;

      int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
      int tb = (tid*jobs < len) ? tid*jobs : len;
      int te = ((tid+1)*jobs < len) ? (tid+1)*jobs : len;

      for(int j=tb; j<te; j++)
      {
        delwt_32_0.i[1] = in[j];
        out[j] = delwt_32_0.f;
      }
    }
  }
#endif
}

void ConvNode::forwardPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifndef NDEBUG
  // printf("Executing FP %s: input %p, weights %p, output %p\n", NNNode::nname_.c_str(), bot, wt, top);
  printf("Executing FP %s\n", NNNode::nname_.c_str());
  printf("Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Weights: %d x %d x %d x %d\n", ifm, ofm, kh, kw);
  printf("Bias: %d\n", ofm);

  if (gparams_.relu) printf("Fused relu\n");
#endif

  impl->set_top_compute_engine(top_compute_engine_);
  impl->set_bot_compute_engine(bot_cengine_);
  impl->set_node_name(nname_);
  impl->set_scratch_buffer(tenScratchData_);

  long long int size = nImg * ofm * ofhp * ofwp;

  if(first_fp)
  {
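    // One-time zeroing of the full output buffer, including the physically
    // padded rim, so the halo regions never contain garbage.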
    if(tenTopData_->getDataType() == DT_FLOAT)
    {
      float* ptr = (float*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(tenTopData_->getDataType() == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

#ifdef CHECK_BLOWUP_FP32
    cbptr = (float*)_mm_malloc(10240*4, 64);
#endif

    first_fp = false;
  }

  if(tenTopData_->getDataType() == DT_FLOAT)
  {
    float* ptr = (float*)tenTopData_->getBuffer();
    if(compute_stats_)
    {
      float* sptr = ptr + size;

      /* @TODO move this into Batch Norm/LIBXSMM */
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<2*nImg*ofm; i++)
        sptr[i] = 0;
    }
  }
  else if(tenTopData_->getDataType() == DT_BF16)
  {
    libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
    if(compute_stats_)
    {
      libxsmm_bfloat16* sptr = ptr + size;

      /* @TODO move this into Batch Norm/LIBXSMM */
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<2*nImg*ofm; i++)
        sptr[i] = 0;
    }
  }

  impl->forwardPropagate(tenBotData_, tenWeightData_, tenWeightInc_, tenBiasData_, tenTopData_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<16; i++)
    {
      float v = ((float*)tenTopData_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenTopData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id == 0)
#endif
  {
    if(out_dtype == DT_FLOAT)
    {
      float *ptr, *pptr, *p;

      if(eptr_->get_current_batch() % STATFREQ == 0)
      {
        string s = nname_ + "_Inp";
        ptr = (float*)tenBotData_->getBuffer();
        pptr = (float*)tenBotData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, nImg*ifm*ifhp*ifwp);

        s = nname_ + "_Wt";
        ptr = (float*)tenWeightData_->getBuffer();
        pptr = (float*)tenWeightData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, ifm*ofm*kh*kw);

        if(gparams_.bias_term)
        {
          s = nname_ + "_Bias";
          p = (float*)tenBiasData_->getBuffer();
          MeanOfLayer((char*)s.c_str(), p, ofm);
        }

        s = nname_ + "_Outp";
        ptr = (float*)tenTopData_->getBuffer();
        pptr = (float*)tenTopData_->getPrivBuffer();
        p = (pptr == NULL) ? ptr : pptr;
        MeanOfLayer((char*)s.c_str(), p, nImg*ofm*ofhp*ofwp);

        if(compute_stats_)
        {
          s = nname_ + "_sump";
          // stats are stored as FP32 (see buffer sizing in the constructor)
          int offset = nImg*ofm*ofhp*ofwp*sizeof(float);
          void* m = (void*)p + offset;
          MeanOfLayer((char*)s.c_str(), (float*)m, nImg*ofm);

          s = nname_ + "_sum2p";
          void* m2 = (void*)m + nImg*ofm*sizeof(float);
          MeanOfLayer((char*)s.c_str(), (float*)m2, nImg*ofm);
        }
      }
    }
    else if(out_dtype == DT_BF16)
    {
      if(stptr == NULL)
      {
        int os = nImg*ofm*ofhp*ofwp;
        int is = nImg*ifm*ifhp*ifwp;
        int ws = ifm*ofm*kh*kw;
        int m = os < is ? is : os;
        int msize = m < ws ? ws : m;
        stptr = (float*)libxsmm_aligned_malloc(msize*sizeof(float), 2097152);
      }

      {
        string s = nname_ + "_Inp";
        libxsmm_bfloat16 *ptr;
        if(tenBotData_->getLPBuffer() != NULL)
          ptr = (libxsmm_bfloat16*)tenBotData_->getLPBuffer();
        else
          ptr = (libxsmm_bfloat16*)tenBotData_->getBuffer();
        convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);

        s = nname_ + "_Wt";
        float *fptr = (float*)tenWeightData_->getBuffer();
        int w = ifm*ofm*kh*kw;
        MeanOfLayer((char*)s.c_str(), fptr, w);

        if(gparams_.bias_term)
        {
          s = nname_ + "_Bias";
          float *p = (float*)tenBiasData_->getBuffer();
          MeanOfLayer((char*)s.c_str(), p, ofm);
        }

        s = nname_ + "_Outp";
        ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
        memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

        if(compute_stats_)
        {
          s = nname_ + "_sump";
          // stats start right after the BF16 activations; offset is in bytes
          int offset = nImg*ofm*ofhp*ofwp*sizeof(libxsmm_bfloat16);
          void* m = (void*)ptr + offset;
          MeanOfLayer((char*)s.c_str(), (float*)m, nImg*ofm);

          s = nname_ + "_sum2p";
          void* m2 = (void*)m + nImg*ofm*sizeof(float);
          MeanOfLayer((char*)s.c_str(), (float*)m2, nImg*ofm);
        }
      }
    }
  }
#endif
}

void ConvNode::backPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  printf("Executing BP %s\n", NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Grad Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("Weights: %d x %d x %d x %d\n", ofm, ifm, kh, kw);
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  if(first_bp)
  {
    long long int size = nImg * ifm * ifhp * ifwp;

    if((in_dtype == DT_BF16 && out_dtype == DT_FLOAT)
        || (in_dtype == DT_FLOAT && out_dtype == DT_FLOAT))
    {
      float* ptr = (float*)tenBotDiff_->getBuffer();
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenBotDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

    first_bp = false;
  }

  impl->backPropagate(tenTopData_, tenWeightData_, tenTopDiff_, tenBotDiff_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenBotDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenBotDiff_->getBuffer(), cbptr, 10240);
#ifdef USE_MLSL
    int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
    int node_id = 0;
#endif
    if(node_id == 0)
    {
      for(int i=0; i<10240; i++)
      {
        if(isnan(cbptr[i]) || isinf(cbptr[i]))
        {
          printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
          MeanOfLayer((char*)((nname_+"_delin").c_str()), (libxsmm_bfloat16*)tenBotDiff_->getBuffer(), nImg*ifm*ifhp*ifwp);
          MeanOfLayer((char*)((nname_+"_delout").c_str()), (libxsmm_bfloat16*)tenTopDiff_->getBuffer(), nImg*ofm*ofhp*ofwp);
          MeanOfLayer((char*)((nname_+"_weight").c_str()), (libxsmm_bfloat16*)tenWeightData_->getLPBuffer(), ofm*ifm*kh*kw);
#ifdef USE_MLSL
          MPI_Finalize();
#endif
          exit(-1);
        }
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id_ == 0)
#endif
  {
    if(eptr_->get_current_batch() % STATFREQ == 0)
    {
      if(in_dtype == DT_FLOAT && out_dtype == DT_FLOAT)
      {
        string s = nname_ + "_delOutp";

        float *ptr = (float*)tenTopDiff_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*ofhp*ofwp);

        s = nname_ + "_Wt";
        ptr = (float*)tenWeightData_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, ifm*ofm*kh*kw);

        s = nname_ + "_delInp";
        ptr = (float*)tenBotDiff_->getBuffer();
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm*ifhp*ifwp);
      }
      else if(in_dtype == DT_BF16 && out_dtype == DT_BF16)
      {
        string s = nname_ + "_delOutp";

        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
        memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

        s = nname_ + "_Wt";
        float *fptr = (float*)tenWeightData_->getBuffer();
        MeanOfLayer((char*)s.c_str(), fptr, ifm*ofm*kh*kw);

        s = nname_ + "_delInp";
        ptr = (libxsmm_bfloat16*)tenBotDiff_->getBuffer();
        memset(stptr, 0, nImg*ifm*ifhp*ifwp*sizeof(float));
        convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);
      }
    }
  }
#endif
}

void ConvNode::weightUpdate()
{
  int nImg = gparams_.batch_size;
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifw = gparams_.iWidth;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  // printf("Executing WU %s: grad_output %p, grad_weights %p, input %p\n", NNNode::nname_.c_str(), gtop, gwt, bot);
  printf("Executing WU %s\n", NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Inputs: %d x %d x %d\n", ifm, ifh, ifw);
  printf("del-Weights: %d x %d x %d x %d\n", ofm, ifm, kh, kw);
  printf("del-Biases: %d\n", ofm);
#endif

#ifdef GETSTATS
  int node_id = 0;
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
  if(node_id == 0 && eptr_->get_current_batch() % STATFREQ == 0)
#else
  if(eptr_->get_current_batch() % STATFREQ == 0)
#endif
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_delWt_Bef";
      float *ptr = (float*)tenWeightDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, ifm*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_delWt_Bef";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm*ofm*kh*kw);
    }

    if(gparams_.bias_term)
    {
      string s = nname_ + "_delBias_Bef";
      float *p = (float*)tenBiasDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), p, ofm);
    }
  }
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  impl->weightUpdate(tenBotData_, tenTopDiff_, tenWeightDiff_, tenBiasDiff_);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenWeightDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer weight-gradients are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
#ifdef BF16_MLSL
    void **wptrptr = tenWeightDiff_->getBufferPtr();
#else
    void **wptrptr = tenWeightDiff_->getLPBufferPtr();
#endif
    int offset = tenWeightDiff_->getOffset();
    void* bf16_wtdiff = wptrptr[0] + offset*sizeof(libxsmm_bfloat16);

    convert_bf16_f32((libxsmm_bfloat16*)bf16_wtdiff, cbptr, 10240);
#ifdef USE_MLSL
    int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
    int node_id = 0;
#endif
    if(node_id == 0)
    {
      for(int i=0; i<10240; i++)
      {
        if(isnan(cbptr[i]) || isinf(cbptr[i]))
        {
          printf("Warning! %s layer weight-gradients are NaN or Inf\n", nname_.c_str());
          MeanOfLayer((char*)nname_.c_str(), (libxsmm_bfloat16*)bf16_wtdiff, ofm*ifm*kh*kw);
          exit(-1);
        }
      }
    }
  }
#endif

#ifdef USE_MLSL
  void *mptr = tenWeightDiff_->getBuffer();

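  // Kick off the asynchronous gradient allreduce; solverStep() waits on it
  // via WaitGradientComm(). Without BF16_MLSL, a BF16 weight gradient is
  // first upconverted to FP32 before communication.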
#ifndef BF16_MLSL
  void *lmptr = tenWeightDiff_->getLPBuffer();

  if(in_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)lmptr, (float*)mptr, ifm*ofm*kh*kw);
    op_->GetParameterSet(0)->StartGradientComm(mptr);
  }
  else if(in_dtype == DT_FLOAT)
    op_->GetParameterSet(0)->StartGradientComm(mptr);
#else
  op_->GetParameterSet(0)->StartGradientComm(mptr);
#endif

  if(gparams_.bias_term)
    op_->GetParameterSet(1)->StartGradientComm(tenBiasDiff_->getBuffer());
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_Inp";
      float *ptr = (float*)tenBotData_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm*ifhp*ifwp);
      s = nname_ + "_delOutp";
      ptr = (float*)tenTopDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*ofhp*ofwp);

      s = nname_ + "_delWt_Aft";
      ptr = (float*)tenWeightDiff_->getBuffer();
      float *pptr = (float*)tenWeightDiff_->getPrivBuffer();
      float *p = (pptr == NULL) ? ptr : pptr;
      MeanOfLayer((char*)s.c_str(), p, ifm*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_Inp";
      libxsmm_bfloat16 *ptr;
      if(tenBotData_->getLPBuffer() != NULL)
        ptr = (libxsmm_bfloat16*)tenBotData_->getLPBuffer();
      else
        ptr = (libxsmm_bfloat16*)tenBotData_->getBuffer();

      memset(stptr, 0, nImg*ifm*ifhp*ifwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ifm*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm*ifhp*ifwp);

      s = nname_ + "_delOutp";
      ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
      memset(stptr, 0, nImg*ofm*ofhp*ofwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ofm*ofhp*ofwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*ofhp*ofwp);

      s = nname_ + "_delWt_Aft";
      ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm*ofm*kh*kw);
    }

    if(gparams_.bias_term)
    {
      string s = nname_ + "_delBias_Aft";
      float *p = (float*)tenBiasDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), p, ofm);
    }
  }
#endif
}

void ConvNode::solverStep()
{
#ifdef USE_MLSL
  int ifm = gparams_.nInput;
  int ofm = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

  void *gwt = tenWeightDiff_->getBuffer();

  float *gbias;
  if(gparams_.bias_term)
    gbias = (float*)(tenBiasDiff_->getBuffer());

  int wsize = ifm*ofm*kh*kw;
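  // Wait for the gradient allreduce started in weightUpdate(). MLSL may
  // return its own buffer holding the reduced gradient, in which case it is
  // copied back (and, for BF16 runs, downconverted) into our gradient buffer.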
  void *mptr = op_->GetParameterSet(0)->WaitGradientComm();
  if(in_dtype == DT_FLOAT)
  {
    if(mptr != NULL && mptr != gwt)
      memcpy((void*)gwt, mptr, wsize*sizeof(float));
  }
  else if(in_dtype == DT_BF16)
  {
    if(mptr != NULL && mptr != dwptr)
      memcpy((void*)dwptr, mptr, wsize*sizeof(float));
    convert_f32_bf16(dwptr, (libxsmm_bfloat16*)gwt, wsize);
  }

  if(gparams_.bias_term)
  {
    mptr = op_->GetParameterSet(1)->WaitGradientComm();
    if(mptr != NULL && mptr != gbias)
      memcpy((void*)gbias, mptr, ofm*sizeof(float));
  }
#endif
}