/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar, Alexander Heinecke (Intel Corp.)
******************************************************************************/

#include <string>
#include "FusedConvBN.hpp"
#include "fillers.hpp"

#ifdef USE_MLSL
#include "mpi.h"
#endif


using namespace std;
using namespace gxm;

FusedConvBNNode::FusedConvBNNode(FusedConvBNParams* p, MLEngine* e) : NNNode(p, e)
{
  nname_ = p->get_node_name();
  ntype_ = p->get_node_type();
  bottom_ = p->get_bottom_names();
  top_ = p->get_top_names();
  bp_flag_ = p->get_bprop_flag();
  has_weights_ = true;
  bot_compute_engine_ = p->get_compute_engine();

  tenTop_ = new Tensor(top_[0]);
  assert(tenTop_ != NULL);
  tenTop_->setOwner(this);
  tenTop_->setType(ACT);
  tenTopData_ = tenTop_->getBuf(DATA);
  tenTopData_->setBufferType(DATA);

  tenMid_ = new Tensor("mid_"+top_[0]);
  assert(tenMid_ != NULL);
  tenMid_->setOwner(this);
  tenMid_->setType(ACT);
  tenMidData_ = tenMid_->getBuf(DATA);
  tenMidData_->setBufferType(DATA);

  tenBot_.resize(bottom_.size());
  tenBotData_.resize(bottom_.size());

  for(int i=0; i < bottom_.size(); i++)
  {
#ifndef NDEBUG
    printf("bottom%d name %s\n",i,bottom_[i].c_str());
#endif

    if(bottom_[i] == "data")
      tenBot_[i] = e->get_tensor(bottom_[i], INPUT);
    else
      tenBot_[i] = e->get_tensor(bottom_[i], ACT);

    assert(tenBot_[i] != NULL);
    NNNode *pnn = (NNNode*)tenBot_[i]->getOwner();
    setPrevNode(pnn);
    mode_ = pnn->getMode();
    pnn->set_top_compute_engine(p->get_compute_engine());
    bot_cengine_ = pnn->get_bot_compute_engine();

    tenBotData_[i] = tenBot_[i]->getBuf(DATA);
  }

  in_dtype = tenBotData_[0]->getDataType();
  out_dtype = p->get_data_type();
  tenTopData_->setDataType(out_dtype);

  // Get input tensor shape (bottom)
  Shape* bs = tenBot_[0]->getShape();
  assert(bs->ndims <= MAX_DIMS);

  // Create shape of output tensor (top)
  vector<int> vd = p->get_kernel_dims();
  vector<int> mvp = p->get_mid_pads();
  vector<int> ovp = p->get_top_pads();
  vector<int> ivp = p->get_bot_pads();
  vector<int> vcs = p->get_c_strides();
  vector<int> vbns = p->get_bn_strides();

  shape_setzero(&ms_);
  ms_.ndims = bs->ndims;          // Number of dimensions
  ms_.dims[0] = bs->dims[0];      // Minibatch size
  ms_.dims[1] = p->get_output();  // Num output feature maps
  ms_.dims[2] = (bs->dims[2] - vd[0] + 2*ivp[0])/vcs[0] + 1; // Height
  ms_.dims[3] = (bs->dims[3] - vd[1] + 2*ivp[1])/vcs[1] + 1; // Width

  tenMid_->setShape(&ms_);

  shape_setzero(&ts_);
  ts_.ndims = bs->ndims;          // Number of dimensions
  ts_.dims[0] = bs->dims[0];      // Minibatch size
  ts_.dims[1] = p->get_output();  // Num output feature maps
  ts_.dims[2] = ms_.dims[2]/vbns[0]; // Height
  ts_.dims[3] = ms_.dims[3]/vbns[1]; // Width

  tenTop_->setShape(&ts_);

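  /* Note on the shape math above (illustrative numbers, not from the original
   * source): the conv (mid) extent per spatial dim follows the usual formula
   *   out = (in - kernel + 2*pad)/stride + 1,
   * e.g. in = 56, kernel = 3, pad = 1, stride = 1 gives (56 - 3 + 2)/1 + 1 = 56.
   * The BN (top) tensor then subsamples the mid tensor by the BN strides, so
   * with bn_stride = 2 a 56x56 mid tensor yields a 28x28 top tensor. */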
  long long int tsize;
  int convelem = ms_.dims[0] * ms_.dims[1] * (ms_.dims[2] + 2*mvp[0]) * (ms_.dims[3] + 2*mvp[1]);
  int bnelem = ts_.dims[0] * ts_.dims[1] * (ts_.dims[2] + 2*ovp[0]) * (ts_.dims[3] + 2*ovp[1]);
  int telem = convelem + bnelem;

  if(out_dtype == DT_FLOAT)
    tsize = telem*sizeof(float);
  else if(out_dtype == DT_BF16)
    tsize = telem*sizeof(libxsmm_bfloat16);

  tenTopData_->setBufferSize(tsize);

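  /* The top buffer is sized to hold both the padded conv output (convelem) and
   * the padded BN output (bnelem): forwardPropagate() later points tenMidData_
   * at the front of this allocation and advances tenTopData_ past it, so the
   * two fused stages share one contiguous allocation. */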
  // Create FP weight tensor
  weight_ = top_[0] + "_wt";
  tenWeight_ = new Tensor(weight_);
  assert(tenWeight_ != NULL);
  tenWeight_->setOwner(this);
  tenWeight_->setType(CONVWEIGHT);

  shape_setzero(&ws_);

  ws_.ndims = ts_.ndims;     // Number of dimensions
  ws_.dims[0] = ms_.dims[1]; // Num output feature maps (from mid tensor)
  ws_.dims[1] = bs->dims[1]; // Num input feature maps (from bottom tensor)
  ws_.dims[2] = vd[0];       // Kernel height
  ws_.dims[3] = vd[1];       // Kernel width

  tenWeight_->setShape(&ws_);
  tenWeight_->setBufDataType(DATA, DT_FLOAT);
  tenWeightData_ = tenWeight_->getBuf(DATA);
  tenWeightData_->setBufferType(DATA);

  int welem = 1;
  long long int wsize;
  for(int i=0; i<ws_.ndims; i++)
    welem = welem*ws_.dims[i];

  // size of master weights -- FP32.
  wsize = welem*sizeof(float);

  gparams_.num_numa_nodes = NUM_NUMA_NODES;
  tenWeightData_->setBufferSize(wsize);

  wfiller_type_ = p->get_weight_filler_type();
  variance_norm_ = p->get_variance_norm();
  std_ = p->get_std();

  lr_mult_ = p->get_lr_mult();
  decay_mult_ = p->get_decay_mult();

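  /* The four BN parameter/statistics tensors created below (scale/gamma,
   * shift/beta, running mean, running variance) are all 1-D per-channel
   * tensors of length nOutput, kept in FP32 regardless of the activation
   * datatype. */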
  Shape sss;
  shape_setzero(&sss);
  sss.ndims = 1;
  sss.dims[0] = ts_.dims[1];

  scale_ = top_[0] + "_scale";
  tenScale_ = new Tensor(scale_);
  assert(tenScale_ != NULL);
  tenScale_->setOwner(this);
  tenScale_->setType(BNORMSCALE);
  tenScale_->setShape(&sss);
  tenScaleData_ = tenScale_->getBuf(DATA);
  tenScaleData_->setDataType(DT_FLOAT);
  tenScaleData_->setBufferType(DATA);

  telem = sss.dims[0];
  tsize = telem*sizeof(float);
  tenScaleData_->setBufferSize(tsize);

  shift_ = top_[0] + "_shift";
  tenShift_ = new Tensor(shift_);
  assert(tenShift_ != NULL);
  tenShift_->setOwner(this);
  tenShift_->setType(BNORMSHIFT);
  tenShift_->setShape(&sss);
  tenShiftData_ = tenShift_->getBuf(DATA);
  tenShiftData_->setDataType(DT_FLOAT);
  tenShiftData_->setBufferType(DATA);

  tenShiftData_->setBufferSize(tsize);

  mean_ = top_[0] + "_mean";
  tenMean_ = new Tensor(mean_);
  assert(tenMean_ != NULL);
  tenMean_->setOwner(this);
  tenMean_->setType(BNORMMEAN);
  tenMean_->setShape(&sss);
  tenMeanData_ = tenMean_->getBuf(DATA);
  tenMeanData_->setDataType(DT_FLOAT);
  tenMeanData_->setBufferType(DATA);
  tenMeanData_->setBufferSize(tsize);

  var_ = top_[0] + "_var";
  tenVar_ = new Tensor(var_);
  assert(tenVar_ != NULL);
  tenVar_->setOwner(this);
  tenVar_->setType(BNORMVAR);
  tenVar_->setShape(&sss);
  tenVarData_ = tenVar_->getBuf(DATA);
  tenVarData_->setDataType(DT_FLOAT);
  tenVarData_->setBufferType(DATA);
  tenVarData_->setBufferSize(tsize);

  if(!e->is_inference_only()) {
    if(bp_flag_)
    {
      tenBotDiff_.resize(bottom_.size());
      for(int i=0; i<bottom_.size(); i++)
      {
        tenBotDiff_[i] = tenBot_[i]->addBuf(); // DIFF type and index
        tenBotDiff_[i]->setDataType(in_dtype);
        tenBotDiff_[i]->setBufferType(DIFF);

        // Set the size of the input-gradient buffer
        Shape *bs = tenBot_[i]->getShape();
        int botelem = bs->dims[0] * bs->dims[1] * (bs->dims[2] + 2*ivp[0]) * (bs->dims[3] + 2*ivp[1]);
        if(in_dtype == DT_FLOAT)
          tenBotDiff_[i]->setBufferSize((botelem + convelem)*sizeof(float));
        else if(in_dtype == DT_BF16)
          tenBotDiff_[i]->setBufferSize((botelem + convelem)*sizeof(libxsmm_bfloat16));
      }
      tenMidDiff_ = tenMid_->addBuf(); // DIFF type and index
      tenMidDiff_->setDataType(in_dtype);
      tenMidDiff_->setBufferType(DIFF);
    }

    if(has_weights_)
    {
      if(tenMidDiff_ == NULL)
      {
        tenMidDiff_ = tenMid_->addBuf(); // DIFF type and index
        tenMidDiff_->setDataType(in_dtype);
        tenMidDiff_->setBufferType(DIFF);
        if(in_dtype == DT_FLOAT)
          tenMidDiff_->setBufferSize(convelem*sizeof(float));
        else if(in_dtype == DT_BF16)
          tenMidDiff_->setBufferSize(convelem*sizeof(libxsmm_bfloat16));
      }

      tenWeightDiff_ = tenWeight_->addBuf(); // DIFF type and index
      tenWeightDiff_->setBufferType(DIFF);

      tenWeightInc_ = tenWeight_->addBuf(); // HISTORY type and index
      tenWeightInc_->setDataType(DT_FLOAT);
      tenWeightInc_->setBufferType(HISTORY);
      tenWeightInc_->setBufferSize(welem*sizeof(float));

      // Set the size of the weight-gradient buffer and the weight-increment buffer
      if(in_dtype == DT_FLOAT)
      {
        tenWeightDiff_->setDataType(DT_FLOAT);
        tenWeightDiff_->setBufferSize(welem*sizeof(float));
      }
      else if(in_dtype == DT_BF16)
      {
        tenWeightDiff_->setDataType(DT_BF16);
        tenWeightDiff_->setBufferSize(welem*sizeof(libxsmm_bfloat16));
      }

      tenScaleDiff_ = tenScale_->addBuf();
      tenScaleDiff_->setDataType(DT_FLOAT);
      tenScaleDiff_->setBufferType(DIFF);
      tenScaleDiff_->setBufferSize(tsize);

      tenScaleInc_ = tenScale_->addBuf();
      tenScaleInc_->setDataType(DT_FLOAT);
      tenScaleInc_->setBufferType(HISTORY);
      tenScaleInc_->setBufferSize(tsize);

      tenShiftDiff_ = tenShift_->addBuf();
      tenShiftDiff_->setDataType(DT_FLOAT);
      tenShiftDiff_->setBufferType(DIFF);
      tenShiftDiff_->setBufferSize(tsize);

      tenShiftInc_ = tenShift_->addBuf();
      tenShiftInc_->setDataType(DT_FLOAT);
      tenShiftInc_->setBufferType(HISTORY);
      tenShiftInc_->setBufferSize(tsize);
    }
  }
  else {
    tenMidDiff_ = NULL;
    tenWeightDiff_ = NULL;
    tenWeightInc_ = NULL;
    tenScaleDiff_ = NULL;
    tenShiftDiff_ = NULL;
    tenScaleInc_ = NULL;
    tenShiftInc_ = NULL;
  }

  // Register output tensor in tensor map
  bool inserted = e->register_tensor(top_[0], ACT, tenTop_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",top_[0].c_str());

  string m = "mid_"+top_[0];
  inserted = e->register_tensor(m, ACT, tenMid_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",m.c_str());

  // Register weight tensor in weight tensor map
  inserted = e->register_tensor(weight_, CONVWEIGHT, tenWeight_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",weight_.c_str());

  inserted = e->register_tensor(scale_, BNORMSCALE, tenScale_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",scale_.c_str());

  inserted = e->register_tensor(shift_, BNORMSHIFT, tenShift_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",shift_.c_str());

  inserted = e->register_tensor(mean_, BNORMMEAN, tenMean_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",mean_.c_str());

  inserted = e->register_tensor(var_, BNORMVAR, tenVar_);
  if(!inserted)
    printf("Warning: Tensor %s already registered\n",var_.c_str());

  // Setup parameter structure for convolution computation in library
  gparams_.bdims = bs->ndims;
  gparams_.tdims = ts_.ndims;
  gparams_.mdims = ms_.ndims;
  gparams_.wdims = ws_.ndims;

  gparams_.node_name = nname_;
  gparams_.node_type = ntype_;
  gparams_.nInput.resize(bottom_.size());
  gparams_.nInput[0] = bs->dims[1];
  if(bottom_.size() > 1)
    gparams_.nInput[1] = tenBot_[1]->getShape()->dims[1];
  gparams_.nOutput = ts_.dims[1];
  gparams_.batch_size = bs->dims[0];
  gparams_.iHeight = bs->dims[2];
  gparams_.iWidth = bs->dims[3];
  gparams_.mHeight = ms_.dims[2];
  gparams_.mWidth = ms_.dims[3];
  gparams_.oHeight = ts_.dims[2];
  gparams_.oWidth = ts_.dims[3];
  gparams_.ipad_h = ivp[0];
  gparams_.ipad_w = ivp[1];
  gparams_.mpad_h = mvp[0];
  gparams_.mpad_w = mvp[1];
  gparams_.opad_h = ovp[0];
  gparams_.opad_w = ovp[1];
  gparams_.physical_padding = p->get_physical_padding();

  gparams_.group = p->get_group();
  gparams_.c_stride_h = vcs[0];
  gparams_.c_stride_w = vcs[1];
  gparams_.bn_stride_h = vbns[0];
  gparams_.bn_stride_w = vbns[1];
  gparams_.kh = ws_.dims[2];
  gparams_.kw = ws_.dims[3];

  gparams_.relu_fwd = p->get_relu_fwd();
  gparams_.relu_bwd = p->get_relu_bwd();

  gparams_.mmf = p->get_mmf();
  gparams_.eps = p->get_eps();
  gparams_.use_global_stats = p->get_global_stats_flag();
  gparams_.eltwise = p->get_eltwise();
  gparams_.bprop = bp_flag_;

  gparams_.in_data_type = in_dtype;
  gparams_.out_data_type = out_dtype;
  gparams_.algType = p->get_algo_type();
  gparams_.num_threads = e->get_num_threads();

  // get solver
  solver_ = e->getSolver();

  //get global scratch tensor buffer
  tenScratchData_ = e->getScratchBuffer();

  // get engine
  eptr_ = e;

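  /* The MLSL setup below registers five gradient-communication parameter sets:
   * 0 = conv weight gradients, 1 = BN scale (gamma), 2 = BN shift (beta),
   * 3 = batch mean, 4 = batch variance; sets 3/4 let every node average the
   * global batch statistics. weightUpdate() and solverStep() index the sets
   * the same way. */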
#ifdef USE_MLSL
  MLSL::DataType dt = MLSL::DT_FLOAT;
  MLSL::OperationRegInfo *myRegInfo;
  MLSL::Session *s = eptr_->get_session();
  myRegInfo = s->CreateOperationRegInfo(MLSL::OT_CC);
  myRegInfo->SetName(nname_.c_str());
  myRegInfo->AddParameterSet(gparams_.nInput[0]*gparams_.nOutput/gparams_.group, gparams_.kw*gparams_.kh, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);
  myRegInfo->AddParameterSet(gparams_.nOutput, 1, dt, false);

  myRegInfo->Validate();
  size_t opIdx = s->AddOperation(myRegInfo, e->get_distribution());
  this->op_ = s->GetOperation(opIdx);
  s->DeleteOperationRegInfo(myRegInfo);
  e->get_combo_grad_comms_vec().push_back(op_);
#endif

  configure(p->get_compute_engine());
}

void FusedConvBNNode::configure(int engine)
{
  switch(engine)
  {
    case XSMM:
      impl = new FusedConvBNXSMM(&gparams_, engine);
      break;
  }
}

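/* Initialize the FP32 master weights. Only rank 0 runs the filler (selected by
 * wfiller_type_, with fan-in/fan-out computed for Xavier/MSRA-style variance
 * scaling); under MLSL the result is then broadcast so all ranks start from
 * identical weights. HISTORY and DIFF buffers are simply zeroed. */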
void FusedConvBNNode::fillWeightBuffers(TensorBuf* tBuf, int buftype, long long int size)
{
  int dtype = DT_FLOAT;
  void *ptr = tBuf->getBuffer();

#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif

  int ic = gparams_.nInput[0];
  int oc = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;
  int g = gparams_.group;
  int fanin = (ic * kh * kw)/g;
  int fanout = (oc * kh * kw)/g;
  int welem = ic * oc * kh * kw;

  if(buftype == DATA)
  {
    if(node_id == 0)
      initBuffer(ptr, variance_norm_, fanin, fanout, welem*sizeof(float), wfiller_type_, std_);

#ifdef USE_MLSL
    if(dtype == DT_FLOAT)
      MPI_Bcast(ptr, welem, MPI_FLOAT, 0, MPI_COMM_WORLD);
#endif
  }
  else if(buftype == HISTORY || buftype == DIFF)
    memset(ptr, 0, size);
}

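/* Per-element learning-rate and weight-decay multipliers: index 0 of lr_mult_/
 * decay_mult_ applies to the conv weights, index 1 to the "bias" parameters,
 * which in this fused node are presumably the BN scale/shift. */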
void FusedConvBNNode::fillWeightMultipliers(float* lr, float* decay, long long int size)
{
  for(int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[0];
    decay[i] = decay_mult_[0];
  }
}

void FusedConvBNNode::fillBiasMultipliers(float* lr, float* decay, long long int size)
{
  for(int i=0; i < size; i++)
  {
    lr[i] = lr_mult_[1];
    decay[i] = decay_mult_[1];
  }
}

void FusedConvBNNode::fillBuffer(TensorBuf* tBuf, int buftype, long long int size)
{
  int ttype = tBuf->getTensor()->getType();
  int dtype = DT_FLOAT;
  void *ptr = tBuf->getBuffer();

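  /* Gamma is initialized to 1 except for layers whose name contains "bn3",
   * which are zero-initialized -- presumably the ResNet trick of zeroing the
   * last BN scale in each residual branch so blocks start out as identities.
   * Shift, mean, and variance buffers all start at 0. */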
  if(ttype==BNORMSCALE && buftype == DATA)
  {
    if(nname_.find("bn3") == nname_.npos)
      initConstantBuffer(ptr, size, "CONSTANT", 1.0f);
    else
      initConstantBuffer(ptr, size, "CONSTANT", 0.0f);
  }
  else
    initConstantBuffer(ptr, size, "CONSTANT", 0.0f);
}

void FusedConvBNNode::Checkpoint(TensorBuf *tBuf, string name, string format)
{
  long long int bytes = tBuf->getBufferSize();
  int dtype = tBuf->getDataType();

  FILE* f;
  void* ptr;
  size_t pos;

  if((name.find("30") == name.npos) && (name.find("60") == name.npos) && (name.find("80") == name.npos))
    while((pos = name.find("/", 10)) != name.npos)
      name.replace(pos, 1, 1, '_');

  float* p = (float*)tBuf->getBuffer();
  bool no_checkpt = false;
  for(int i=0; i<16; i++)
  {
    if(isnan(p[i]) || isinf(p[i]))
    {
      no_checkpt = true;
      printf("Warning! %s Did not checkpoint! Weights are NaNs or Inf\n", nname_.c_str());
      break;
    }
  }

  if(!no_checkpt)
  {
    if(format.compare("binary") == 0)
    {
      f = fopen(name.c_str(), "wb");
      if(f != NULL)
      {
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else if(name.find("mean") != name.npos || name.find("var") != name.npos)
          ptr = tBuf->getPrivBuffer();
        else
          ptr = tBuf->getBuffer();

        size_t b = fwrite(ptr, 1, bytes, f);
        assert((long long int)b == bytes);

        if(name.find("wt") != name.npos)
          _mm_free(ptr);
      }
      else
        printf("Warning: could not checkpoint to file %s\n",name.c_str());
    }
    else
    {
      f = fopen(name.c_str(), "w");
      if(f != NULL)
      {
        if(name.find("wt") != name.npos)
        {
          ptr = _mm_malloc(bytes, 64);
          assert(ptr != NULL);
          impl->dumpBuffer(tBuf, ptr);
        }
        else
          ptr = tBuf->getBuffer();

        for(int i=0; i<bytes/sizeof(float); i++)
          fprintf(f, "%f\n", *((float*)ptr + i));

        if(name.find("wt") != name.npos)
          _mm_free(ptr);
      }
      else
        printf("Warning: could not checkpoint to file %s\n",name.c_str());
    }
    if(f != NULL)
    {
      fflush(f);
      fclose(f);
    }
  }
}

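/* FP32 <-> BF16 conversion helpers. The FP32->BF16 path applies a
 * round-to-nearest-even adjustment before truncating to the upper 16 bits;
 * both loops process 16 floats per iteration with AVX-512F intrinsics and
 * assume len is a multiple of 16. */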
void FusedConvBNNode::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(in + i));
    __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32);
    _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
  }
}

void FusedConvBNNode::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m256i vbfp16 = _mm256_loadu_si256( (const __m256i*)(in+i) );
    __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 );
    _mm512_storeu_ps( out+i, vfp32 );
  }
}

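/* Forward pass of the fused node: convolution into the mid tensor, then batch
 * normalization (optionally fused with ReLU/eltwise) into the top tensor. On
 * the first call the shared output allocation is split between mid and top,
 * and both halves are touched by the same OpenMP threads that will later use
 * them, so pages land on the right NUMA nodes (first-touch placement). */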
void FusedConvBNNode::forwardPropagate()
{
  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ifm1 = gparams_.eltwise ? gparams_.nInput[1] : 0;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int oph = gparams_.opad_h;
  int opw = gparams_.opad_w;
  int ofhp = ofh + 2*oph;
  int ofwp = ofw + 2*opw;
  int bnsh = gparams_.bn_stride_h;
  int bnsw = gparams_.bn_stride_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifndef NDEBUG
  // printf("Executing FP %s: input %p, weights %p, output %p\n",NNNode::nname_.c_str(), bot, wt, top);
  printf("Executing FP %s\n",NNNode::nname_.c_str());
  printf("Inputs: %d x %d x %d\n",ifm0, ifh, ifw);
  printf("Outputs: %d x %d x %d\n",ofm, ofh, ofw);
  printf("Weights: %d x %d x %d x %d\n", ifm0, ofm, kh, kw);
  printf("Bias: %d\n", ofm);
#endif

  if(first_fp)
  {
    impl->set_top_compute_engine(top_compute_engine_);
    impl->set_bot_compute_engine(bot_cengine_);
    impl->set_node_name(nname_);
    impl->set_scratch_buffer(tenScratchData_);

    if(eptr_->get_execution_mode() == TRAIN || eptr_->get_execution_mode() == VAL)
    {
      impl->set_global_stats(false);
      gparams_.exec_mode = "TRAIN";
    }
    else if(eptr_->get_execution_mode() == TEST)
      impl->set_global_stats(true);

    tenMidData_->setBuffer(tenTopData_->getBuffer());

    if(out_dtype == DT_FLOAT)
    {
      float* ptr = (float*)tenMidData_->getBuffer();
      int size = nImg * ofm * mfhp * mfwp;
      tenMidData_->setBufferSize(size*sizeof(float));
      tenTopData_->setBuffer(tenTopData_->getBuffer() + size*sizeof(float));
      tenTopData_->setBufferSize(tenTopData_->getBufferSize() - size*sizeof(float));

      // NUMA initialize Conv output
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;

      // NUMA initialize BN output
      size = nImg * ofm * (ofh/bnsh + 2*oph) * (ofw/bnsw + 2*opw);
      ptr = (float*)tenTopData_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
      int size = nImg * ofm * mfhp * mfwp;
      tenMidData_->setBufferSize(size*sizeof(libxsmm_bfloat16));
      tenTopData_->setBuffer(tenTopData_->getBuffer() + size*sizeof(libxsmm_bfloat16));
      tenTopData_->setBufferSize(tenTopData_->getBufferSize() - size*sizeof(libxsmm_bfloat16));

      // NUMA initialize Conv output
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;

      // NUMA initialize BN output
      ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
      size = nImg * ofm * (ofh/bnsh + 2*oph) * (ofw/bnsw + 2*opw);

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<size; i++)
        ptr[i] = 0;
    }

    cbptr = (float*)_mm_malloc(10240*4, 64);
    scf_ = eptr_->get_scaling_factor();
    impl->set_scaling_factor(scf_);

    first_fp = false;
  }

  impl->forwardPropagate(tenBotData_, tenWeightData_, tenWeightInc_, tenMidData_, tenScaleData_, tenShiftData_, tenMeanData_, tenVarData_, tenTopData_, 0);

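  /* Running-average scaling factor for the global BN statistics: during
   * training it is updated as scf = scf*mmf + 1, i.e. the geometric-series
   * normalizer 1 + mmf + mmf^2 + ... used to unbias a moving average with
   * momentum mmf. */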
  if(eptr_->get_execution_mode() != TEST && eptr_->get_execution_mode() != VAL)
  {
    scf_ *= gparams_.mmf;
    scf_ += 1.;

    eptr_->set_scaling_factor(scf_);
  }

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<10240; i++)
    {
      float v = ((float*)tenTopData_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenMidData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP mid activations are NaN or Inf\n", nname_.c_str());
        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
        printf("cbptr[%d] = %d, cbptr[%d] = %f\n",i,ptr[i],i,cbptr[i]);
        exit(-1);
      }
    }
    convert_bf16_f32((libxsmm_bfloat16*)tenTopData_->getBuffer(), cbptr, 10240);
    for(int i=0; i<10240; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer FP activations are NaN or Inf\n", nname_.c_str());
        libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
        printf("cbptr[%d] = %d, cbptr[%d] = %f\n",i,ptr[i],i,cbptr[i]);
        exit(-1);
      }
    }
  }
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  unsigned int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenBotData_[0]->getBuffer();
      string s = nname_ + "_r_Inp";
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm0*ifhp*ifwp);

      if(gparams_.nInput.size() > 1)
      {
        ptr = (float*)tenBotData_[1]->getBuffer();
        s = nname_ + "_l_Inp";
        MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm1*ifhp*ifwp);
      }

      ptr = (float*)tenMidData_->getBuffer();
      s = nname_ + "_mid";
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*mfhp*mfwp);
    }
    else if(in_dtype == DT_BF16)
    {
      if(stptr == NULL)
      {
        int s = nImg*ofm*ofhp*ofwp;
        int ms = nImg*ofm*mfhp*mfwp;
        int is = nImg*ifm0*ifhp*ifwp;
        int is1=0;
        if(gparams_.nInput.size() > 1)
          is1 = nImg*ifm1*ifhp*ifwp;

        int size = s > ms ? s : ms;
        size = size > is ? size : is;
        size = size > is1 ? size : is1;

        stptr = (float*)libxsmm_aligned_malloc(size*sizeof(float), 2097152);
      }

      libxsmm_bfloat16 *ptr;
      if(tenBotData_[0]->getLPBuffer() != NULL)
        ptr = (libxsmm_bfloat16*)tenBotData_[0]->getLPBuffer();
      else
        ptr = (libxsmm_bfloat16*)tenBotData_[0]->getBuffer();

      string s = nname_ + "_r_Inp";
      convert_bf16_f32(ptr, stptr, nImg*ifm0*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm0*ifhp*ifwp);

      if(gparams_.nInput.size() > 1)
      {
        if(tenBotData_[1]->getLPBuffer() != NULL)
          ptr = (libxsmm_bfloat16*)tenBotData_[1]->getLPBuffer();
        else
          ptr = (libxsmm_bfloat16*)tenBotData_[1]->getBuffer();

        convert_bf16_f32(ptr, stptr, nImg*ifm1*ifhp*ifwp);
        s = nname_ + "_l_Inp";
        MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm1*ifhp*ifwp);
      }

      ptr = (libxsmm_bfloat16*)tenMidData_->getBuffer();
      convert_bf16_f32(ptr, stptr, nImg*ofm*mfhp*mfwp);
      s = nname_ + "_mid";
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*mfhp*mfwp);
    }

    string s = nname_ + "_wt";
    float* wt = (float*)tenWeightData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), wt, ifm0*ofm*kh*kw);

    s = nname_ + "_gammap";
    float* gamma = (float*)tenScaleData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), gamma, gparams_.nOutput);

    s = nname_ + "_betap";
    float* beta = (float*)tenShiftData_->getBuffer();
    MeanOfLayer((char*)s.c_str(), beta, gparams_.nOutput);

    if(out_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenTopData_->getBuffer();
      string s = nname_ + "_Outp";
      int size = nImg*ofm*(ofh/bnsh + 2*oph)*(ofw/bnsw + 2*opw);
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopData_->getBuffer();
      s = nname_ + "_Outp";
      int size = nImg*ofm*(ofh/bnsh + 2*oph)*(ofw/bnsw + 2*opw);
      convert_bf16_f32(ptr, stptr, size);
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }
  }
#endif
}

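/* Backward pass: BN backprop produces delgamma/delbeta and the gradient w.r.t.
 * the conv output (tenMidDiff_), which the conv backward then turns into the
 * input gradient(s). As in the forward pass, the first call splits the shared
 * gradient allocation between bottom and mid and NUMA-initializes it. */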
void FusedConvBNNode::backPropagate()
{

  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ifm1 = gparams_.eltwise ? gparams_.nInput[1] : 0;
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifw = gparams_.iWidth;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ofh = gparams_.oHeight;
  int ofw = gparams_.oWidth;
  int ofhp = ofh + 2*gparams_.opad_h;
  int ofwp = ofw + 2*gparams_.opad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  printf("Executing BP %s\n",NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, ofh, ofw);
  printf("Grad Inputs: %d x %d x %d\n", ifm0, ifh, ifw);
  printf("Weights: %d x %d x %d x %d\n", ofm, ifm0, kh, kw);
#endif

  tenTopDiff_ = tenTop_->getBuf(DIFF);

  if(first_bp)
  {
    int bsize0 = nImg*ifm0*ifhp*ifwp;
    int bsize1 = nImg*ifm1*ifhp*ifwp;
    int msize = nImg*ofm*mfhp*mfwp;

    if(in_dtype == DT_FLOAT)
    {
      float* ptr = (float*)tenBotDiff_[0]->getBuffer();
      tenMidDiff_->setBuffer(tenBotDiff_[0]->getBuffer() + bsize0*sizeof(float));
      tenMidDiff_->setBufferSize(msize*sizeof(float));
      tenBotDiff_[0]->setBufferSize(bsize0*sizeof(float));
      if(gparams_.eltwise)
        tenBotDiff_[1]->setBufferSize(bsize1*sizeof(float));

      // NUMA initialize Conv delinp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<bsize0; i++)
        ptr[i] = 0;

      // NUMA initialize BN delinp = Conv delmidp
      ptr = (float*)tenMidDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;

      ptr = gparams_.eltwise ? (float*)tenBotDiff_[1]->getBuffer() : NULL;
      if(ptr)
      {
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for(int i=0; i<bsize1; i++)
          ptr[i] = 0;
      }
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenBotDiff_[0]->getBuffer();
      tenMidDiff_->setBuffer(tenBotDiff_[0]->getBuffer() + bsize0*sizeof(libxsmm_bfloat16));
      tenMidDiff_->setBufferSize(msize*sizeof(libxsmm_bfloat16));
      tenBotDiff_[0]->setBufferSize(bsize0*sizeof(libxsmm_bfloat16));
      if(gparams_.eltwise)
        tenBotDiff_[1]->setBufferSize(bsize1*sizeof(libxsmm_bfloat16));

      // NUMA initialize Conv delinp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<bsize0; i++)
        ptr[i] = 0;

      // NUMA initialize BN delinp = Conv delmidp
      ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();

#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;

      ptr = gparams_.eltwise ? (libxsmm_bfloat16*)tenBotDiff_[1]->getBuffer() : NULL;
      if(ptr)
      {
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for(int i=0; i<bsize1; i++)
          ptr[i] = 0;
      }
    }
    first_bp = false;
  }

  impl->backPropagate(tenTopDiff_, tenWeightData_, tenScaleDiff_, tenShiftDiff_, tenMidDiff_, tenBotDiff_, 0);

#ifdef CHECK_BLOWUP_FP32
  float* cbptr = (float*)tenTopDiff_->getBuffer();
  for(int i=0; i<10240; i++)
  {
    if(isnan(cbptr[i]) || isinf(cbptr[i]))
    {
      printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
#endif

#ifdef GETSTATS
  float *ptr, *pptr, *p, *bias;
#ifdef USE_MLSL
  unsigned int node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  unsigned int node_id_ = 0;
#endif
  if(node_id_ == 0)
  {
    int sh = gparams_.bn_stride_h;
    int sw = gparams_.bn_stride_w;
    int ph = gparams_.opad_h;
    int pw = gparams_.opad_w;

    if(out_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenTopDiff_->getBuffer();

      int size = nImg*ofm*ofhp*ofwp;
      string s = nname_ + "_delOutp";
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(out_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenTopDiff_->getBuffer();
      int size = nImg*ofm*ofhp*ofwp;
      convert_bf16_f32(ptr, stptr, size);
      string s = nname_ + "_delOutp";
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }

    string s = nname_ + "_delgammap";
    float* delgamma = (float*)tenScaleDiff_->getBuffer();
    MeanOfLayer((char*)s.c_str(), delgamma, gparams_.nOutput);

    s = nname_ + "_delbetap";
    float* delbeta = (float*)tenShiftDiff_->getBuffer();
    MeanOfLayer((char*)s.c_str(), delbeta, gparams_.nOutput);

    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenBotDiff_[0]->getBuffer();
      string s = nname_ + "_delInp";
      int size = nImg*ifm0*ifhp*ifwp;
      MeanOfLayer((char*)s.c_str(), ptr, size);
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenBotDiff_[0]->getBuffer();
      s = nname_ + "_delInp";
      int size = nImg*ifm0*ifhp*ifwp;
      convert_bf16_f32(ptr, stptr, size);
      MeanOfLayer((char*)s.c_str(), stptr, size);
    }
  }
#endif
}

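/* Weight-update pass: computes the conv weight gradient (plus delgamma/delbeta
 * via the library call) and, under MLSL, kicks off the asynchronous gradient
 * allreduce; solverStep() later waits for its completion. */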
void FusedConvBNNode::weightUpdate()
{
  int nImg = gparams_.batch_size;
  int ifm0 = gparams_.nInput[0];
  int ofm = gparams_.nOutput;
  int ifh = gparams_.iHeight;
  int ifw = gparams_.iWidth;
  int mfh = gparams_.mHeight;
  int mfw = gparams_.mWidth;
  int mfhp = mfh + 2*gparams_.mpad_h;
  int mfwp = mfw + 2*gparams_.mpad_w;
  int ifhp = ifh + 2*gparams_.ipad_h;
  int ifwp = ifw + 2*gparams_.ipad_w;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

#ifdef DEBUG
  // printf("Executing WU %s: grad_output %p, grad_weights %p, input %p\n",NNNode::nname_.c_str(), gtop, gwt, bot);
  printf("Executing WU %s\n",NNNode::nname_.c_str());
  printf("Grad Outputs: %d x %d x %d\n", ofm, mfh, mfw);
  printf("Inputs: %d x %d x %d\n",ifm0, ifh, ifw);
  printf("del-Weights: %d x %d x %d x %d\n", ofm, ifm0, kh, kw);
  printf("del-Biases: %d\n", ofm);
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  int node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  int node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_delWt_Bef";
      float *ptr = (float*)tenWeightDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, ifm0*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_delWt_Bef";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm0*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm0*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm0*ofm*kh*kw);
    }
  }
#endif

  if(!bp_flag_ && first_upd)
  {
    int msize = nImg*ofm*mfhp*mfwp;

    if(in_dtype == DT_FLOAT)
    {
      float *ptr = (float*)tenMidDiff_->getBuffer();

      // NUMA initialize Conv delmidp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;
    }
    else if(in_dtype == DT_BF16)
    {
      libxsmm_bfloat16* ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();

      // NUMA initialize Conv delmidp
#ifdef _OPENMP
#pragma omp parallel for
#endif
      for(int i=0; i<msize; i++)
        ptr[i] = 0;
    }
    first_upd = false;
  }

  tenTopDiff_ = tenTop_->getBuf(DIFF);
  impl->weightUpdate(tenBotData_[0], tenTopDiff_, tenMidDiff_, tenWeightDiff_, tenScaleDiff_, tenShiftDiff_, 0);

#ifdef CHECK_BLOWUP_FP32
  if(out_dtype == DT_FLOAT)
  {
    for(int i=0; i<16; i++)
    {
      float v = ((float*)tenWeightDiff_->getBuffer())[i];
      if(isnan(v) || isinf(v))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
  else if(out_dtype == DT_BF16)
  {
    convert_bf16_f32((libxsmm_bfloat16*)tenWeightDiff_->getBuffer(), cbptr, 16);
    for(int i=0; i<16; i++)
    {
      if(isnan(cbptr[i]) || isinf(cbptr[i]))
      {
        printf("Warning! %s layer BP activations are NaN or Inf\n", nname_.c_str());
        exit(-1);
      }
    }
  }
#endif

  void* gexp[NUM_NUMA_NODES];
  void* gvar[NUM_NUMA_NODES];
  void* gexp_test = tenMeanData_->getPrivBuffer();
  void* gvar_test = tenVarData_->getPrivBuffer();

  void **mptrptr = tenMeanData_->getBufferPtr();
  void **vptrptr = tenVarData_->getBufferPtr();
  int offset = tenMeanData_->getOffset();

  for(int n=0; n<NUM_NUMA_NODES; n++)
    gexp[n] = mptrptr[n] + offset*sizeof(float);

  offset = tenVarData_->getOffset();
  for(int n=0; n<NUM_NUMA_NODES; n++)
    gvar[n] = vptrptr[n] + offset*sizeof(float);

#ifdef USE_MLSL
  void *mptr = tenWeightDiff_->getBuffer();

  if(in_dtype == DT_BF16)
  {
    if(dwptr == NULL)
    {
      int wsize = ifm0*ofm*kh*kw*sizeof(float);
      dwptr = (float*)MLSL::Environment::GetEnv().Alloc(wsize, 2097152);
    }
    convert_bf16_f32((libxsmm_bfloat16*)mptr, dwptr, ifm0*ofm*kh*kw);
    op_->GetParameterSet(0)->StartGradientComm(dwptr);
  }
  else if(in_dtype == DT_FLOAT)
    op_->GetParameterSet(0)->StartGradientComm(mptr);

  op_->GetParameterSet(1)->StartGradientComm(tenScaleDiff_->getBuffer());
  op_->GetParameterSet(2)->StartGradientComm(tenShiftDiff_->getBuffer());

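  /* Batch mean/variance were accumulated per NUMA node; average them over the
   * NUM_NUMA_NODES local copies, pre-divide by the machine count, and then
   * allreduce (parameter sets 3/4) so every machine ends up with the global
   * average of the batch statistics. */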
  int num_nodes = eptr_->get_num_machines();
  for(int i=0; i<ofm; i++)
  {
    float mtmp = 0.0;
    float vtmp = 0.0;

    for(int n=0; n<NUM_NUMA_NODES; n++)
    {
      mtmp += ((float*)gexp[n])[i];
      vtmp += ((float*)gvar[n])[i];
    }

    mtmp = mtmp/NUM_NUMA_NODES;
    vtmp = vtmp/NUM_NUMA_NODES;

    ((float*)gexp_test)[i] = mtmp/num_nodes;
    ((float*)gvar_test)[i] = vtmp/num_nodes;
  }
  this->op_->GetParameterSet(3)->StartGradientComm(gexp_test);
  this->op_->GetParameterSet(4)->StartGradientComm(gvar_test);
#endif

#ifdef GETSTATS
#ifdef USE_MLSL
  node_id = MLSL::Environment::GetEnv().GetProcessIdx();
#else
  node_id = 0;
#endif
  if(node_id == 0)
  {
    if(in_dtype == DT_FLOAT)
    {
      string s = nname_ + "_Inp";
      float *ptr = (float*)tenBotData_[0]->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ifm0*ifhp*ifwp);

      s = nname_ + "_delMidp";
      ptr = (float*)tenMidDiff_->getBuffer();
      MeanOfLayer((char*)s.c_str(), ptr, nImg*ofm*mfhp*mfwp);

      s = nname_ + "_delWt_Aft";
      ptr = (float*)tenWeightDiff_->getBuffer();
      float *pptr = (float*)tenWeightDiff_->getPrivBuffer();
      float *p = (pptr == NULL) ? ptr : pptr;
      MeanOfLayer((char*)s.c_str(), p, ifm0*ofm*kh*kw);
    }
    else if(in_dtype == DT_BF16)
    {
      string s = nname_ + "_Inp";
      libxsmm_bfloat16 *ptr = (libxsmm_bfloat16*)tenBotData_[0]->getBuffer();
      memset(stptr, 0, nImg*ifm0*ifhp*ifwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ifm0*ifhp*ifwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ifm0*ifhp*ifwp);

      s = nname_ + "_delMidp";
      ptr = (libxsmm_bfloat16*)tenMidDiff_->getBuffer();
      memset(stptr, 0, nImg*ofm*mfhp*mfwp*sizeof(float));
      convert_bf16_f32(ptr, stptr, nImg*ofm*mfhp*mfwp);
      MeanOfLayer((char*)s.c_str(), stptr, nImg*ofm*mfhp*mfwp);

      s = nname_ + "_delWt_Aft";
#ifdef USE_MLSL
      MeanOfLayer((char*)s.c_str(), dwptr, ifm0*ofm*kh*kw);
#else
      ptr = (libxsmm_bfloat16*)tenWeightDiff_->getBuffer();
      memset(stptr, 0, ifm0*ofm*kh*kw*sizeof(float));
      convert_bf16_f32(ptr, stptr, ifm0*ofm*kh*kw);
      MeanOfLayer((char*)s.c_str(), stptr, ifm0*ofm*kh*kw);
#endif
    }
  }
#endif
}

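/* Solver step (MLSL builds only): wait for the gradient allreduces started in
 * weightUpdate() and copy the reduced results back into the local gradient
 * buffers (converting BF16 weight gradients back from the FP32 staging
 * buffer), then sanity-check the first few entries for NaN/Inf. */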
void FusedConvBNNode::solverStep()
{
#ifdef USE_MLSL
  int ifm = gparams_.nInput[0];
  int ofm = gparams_.nOutput;
  int kh = gparams_.kh;
  int kw = gparams_.kw;

  float *gwt = (float*)(tenWeightDiff_->getBuffer());
  float *delgamma = (float*)tenScaleDiff_->getBuffer();
  float *delbeta = (float*)tenShiftDiff_->getBuffer();
  void* gexp_test = tenMeanData_->getPrivBuffer();
  void* gvar_test = tenVarData_->getPrivBuffer();

  int wsize = ifm*ofm*kh*kw;

  void *mptr = op_->GetParameterSet(0)->WaitGradientComm();
  if(in_dtype == DT_FLOAT)
  {
    if(mptr != NULL && mptr != gwt)
      memcpy((void*)gwt, mptr, wsize*sizeof(float));
  }
  else if(in_dtype == DT_BF16)
  {
    if(mptr != NULL && mptr != dwptr)
      memcpy((void*)dwptr, mptr, wsize*sizeof(float));
    convert_f32_bf16(dwptr, (libxsmm_bfloat16*)gwt, wsize);
  }

  mptr = op_->GetParameterSet(1)->WaitGradientComm();
  if(mptr != NULL && mptr != delgamma)
    memcpy((void*)delgamma, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(2)->WaitGradientComm();
  if(mptr != NULL && mptr != delbeta)
    memcpy((void*)delbeta, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(3)->WaitGradientComm();
  if(mptr != NULL && mptr != gexp_test)
    memcpy((void*)gexp_test, mptr, ofm*sizeof(float));

  mptr = op_->GetParameterSet(4)->WaitGradientComm();
  if(mptr != NULL && mptr != gvar_test)
    memcpy((void*)gvar_test, mptr, ofm*sizeof(float));

#ifdef CHECK_BLOWUP_FP32
  float* ptr = (float*)tenWeightDiff_->getBuffer();
  for(int i=0; i<16; i++)
  {
    if(isnan(ptr[i]) || isinf(ptr[i]))
    {
      printf("Warning! %s layer Solver gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
  for(int i=0; i<16; i++)
  {
    if(isnan(delgamma[i]) || isinf(delgamma[i]))
    {
      printf("Warning! %s layer Solver gamma gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
  for(int i=0; i<16; i++)
  {
    if(isnan(delbeta[i]) || isinf(delbeta[i]))
    {
      printf("Warning! %s layer Solver beta gradients are NaN or Inf\n", nname_.c_str());
      exit(-1);
    }
  }
#endif
#endif
}