/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
******************************************************************************/


#include <map>
#include <assert.h>
#include "proto/gxm.pb.h"
#include "Node.hpp"
#include "Engine.hpp"
#include "Conv.hpp"
#include "FullyConnected.hpp"
#include "FusedBNorm.hpp"
#include "FusedConvBN.hpp"
#include "DummyData.hpp"
#include "TypeList.hpp"

#include <unistd.h>
#include <limits.h>

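// SIMD vector length in FP32 elements (one AVX-512 register); packed weight
// and bias buffer sizes are rounded up to a multiple of num_threads_*VLEN.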
#define VLEN 16

using namespace std;
using namespace gxm;

int iter=0;

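// Ordering predicate for tasks in the execution task graph (etg_): a task
// whose latest bin precedes another task's earliest bin must run first.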
bool compare_task_bins(Task* first, Task* second)
{
  return (first->getMaxBin() < second->getMinBin());
}

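// Builds the execution schedule for the given mode by propagating schedule
// bins through the task graph: forward-pass dependencies are placed one bin
// before their consumer and, in TRAIN mode, backward/weight-gradient/solver
// tasks are placed one bin after their producer. Scheduled tasks are appended
// to etg_[mode] as they are discovered.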
void MLEngine::create_schedule(int mode)
{
  for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
  {
    Task* t = *it;
    vector<Task*> tp = t->getBackDepTasks();
    for(int i=0; i<tp.size(); i++) {
      string s = dynamic_cast<NNNode*>(tp[i]->getNode())->getNodeName();

      if(tp[i]->getBasicTaskId() == BASIC_TASK_FORW) {
        int maxbin = tp[i]->getMaxBin();
        if((maxbin == 0) || (maxbin > t->getMinBin()-1))
        {
          tp[i]->setMinBin(t->getMaxBin() - 1);
          tp[i]->setMaxBin(t->getMaxBin() - 1);
          etg_[mode].push_back(tp[i]);
#ifdef DEBUG
          printf("FP task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
#endif
        }
      }
    }
  }

  if(mode == TRAIN)
  {
    for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
    {
      Task* t = *it;
      vector<Task*> tp = t->getForwDepTasks();
      for(int i=0; i<tp.size(); i++)
      {
        string s = dynamic_cast<NNNode*>(tp[i]->getNode())->getNodeName();

        if(tp[i]->getBasicTaskId() != BASIC_TASK_FORW)
        {
          int maxbin = tp[i]->getMaxBin();
          if((maxbin == 0) || (maxbin < t->getMinBin()+1))
          {
            tp[i]->setMinBin(t->getMaxBin() + 1);
            tp[i]->setMaxBin(t->getMaxBin() + 1);
            etg_[mode].push_back(tp[i]);
#ifdef DEBUG
            if(tp[i]->getBasicTaskId() == BASIC_TASK_BACK)
              printf("BP task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
            else if(tp[i]->getBasicTaskId() == BASIC_TASK_WGRAD)
              printf("WU task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
            else if(tp[i]->getBasicTaskId() == BASIC_TASK_SOLVE)
              printf("SOLVE task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
#endif
          }
        }
      }
    }
  }
}

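// Linear search over the registered node types; returns the matching index
// or -1 if the type name is unknown.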
int MLEngine::find_in_nodeTypeList(string name)
{
  for(int i=0; i<numTypes; i++)
    if(nodeTypes[i].typeName.compare(name) == 0)
      return i;
  return -1;
}

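// Registers a tensor in the list/map pair corresponding to its type
// (input/label, activation, weight, bias, or batch-norm statistics).
// Returns false for an unrecognized tensor type.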
bool MLEngine::register_tensor(string name, int type, Tensor* t)
{
  TensorPair tp;
  tp.name = name;
  tp.t = t;

  Iter it;

  switch(type)
  {
    case INPUT:
    case LABEL:
      it = inTList_.insert(inTList_.end(), tp);
      inTensorMap_[name] = it;
      break;

    case ACT:
      it = outTList_.insert(outTList_.end(), tp);
      outTensorMap_[name] = it;
      break;

    case CONVWEIGHT:
    case FCWEIGHT:
      it = wTList_.insert(wTList_.end(), tp);
      weightTensorMap_[name] = it;
      break;

    case CONVBIAS:
    case FCBIAS:
    case BNORMSCALE:
    case BNORMSHIFT:
      it = biasTList_.insert(biasTList_.end(), tp);
      biasTensorMap_[name] = it;
      break;

    case BNORMMEAN:
    case BNORMVAR:
      it = statsTList_.insert(statsTList_.end(), tp);
      statsTensorMap_[name] = it;
      break;

    default:
      return false;
  }
  return true;
}

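// Looks up a previously registered tensor by name and type; returns NULL if
// the lookup yields no entry. Callers are expected to query only names that
// were registered via register_tensor().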
Tensor* MLEngine::get_tensor(string name, int type)
{
  Iter it = defTList_.end();

  switch(type)
  {
    case INPUT:
    case LABEL:
      it = inTensorMap_[name];
      break;

    case ACT:
      it = outTensorMap_[name];
      break;

    case CONVWEIGHT:
    case FCWEIGHT:
      it = weightTensorMap_[name];
      break;

    case CONVBIAS:
    case FCBIAS:
    case BNORMSCALE:
    case BNORMSHIFT:
      it = biasTensorMap_[name];
      break;

    case BNORMMEAN:
    case BNORMVAR:
      it = statsTensorMap_[name];
      break;
  }

  if(it == defTList_.end())
    return NULL;

  TensorPair tp = *it;
  return tp.t;
}

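// Sorts the schedule by bin and drops duplicate task entries; dupChecker_ is
// presumably a predicate (defined elsewhere) that flags tasks it has already
// seen, with list::unique() removing any remaining adjacent repeats.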
void MLEngine::optimize_schedule(int mode)
{
  etg_[mode].sort(compare_task_bins);
  etg_[mode].erase(std::stable_partition(etg_[mode].begin(), etg_[mode].end(), dupChecker_()), etg_[mode].end());
  etg_[mode].unique();
}

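// Zeroes the HISTORY (solver momentum) buffer of every tensor in the list
// that has one.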
void MLEngine::clear_history(TensorList L)
{
  int buftype = HISTORY;

  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;
    TensorBuf *tBuf;
    bool found = false;
    for(int index=0; index<t->getNumDataBuffers(); index++)
    {
      tBuf = t->getBuf(index);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    long long int bytes = tBuf->getBufferSize();

    float *fp = (float*)(tBuf->getBuffer());
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for(long long int i=0; i<(long long int)(bytes/sizeof(float)); i++)
      fp[i] = 0.f;
  }
}

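// Writes the requested buffer (DATA, HISTORY or DIFF) of each tensor to a
// file under checkpoint_dir_, delegating the I/O to the owning node's
// Checkpoint() method. At epochs 30, 60 and 80, an additional epoch-tagged
// snapshot is written under checkpoint_dir_<epoch>/.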
void MLEngine::checkpoint(TensorList L, int buftype)
{
  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;
    TensorBuf *tBuf;
    bool found=false;

    for(int index=0; index<t->getNumDataBuffers(); index++)
    {
      tBuf = t->getBuf(index);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    int tenType = t->getType();
    string tn = t->getTensorName();
    string n = checkpoint_dir_ + "/" + tn;
    if(buftype == HISTORY)
      n = n + "_history";
    else if(buftype == DIFF)
      n = n + "_grad";

    string nntype = dynamic_cast<NNNode*>(t->getOwner())->getNodeType();

    if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
    {
      if(tenType == ACT)
      {
        n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
        if(tn.find("bn") != tn.npos)
        {
          if(nntype == "FusedBatchNorm")
          {
            FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
            bn->Checkpoint(tBuf, n, checkpoint_format_);
          }
          else if(nntype == "FusedConvBN")
          {
            FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
            fcbn->Checkpoint(tBuf, n, checkpoint_format_);
          }
        }
      }
    }

    if((tenType == CONVWEIGHT) || (tenType == CONVBIAS))
    {
      if(nntype == "Convolution")
      {
        ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
        cn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_diff";
          cn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
      else if(nntype == "FusedConvBN")
      {
        FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
        fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
    }
    else if((tenType == FCWEIGHT) || (tenType == FCBIAS))
    {
      FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
      fn->Checkpoint(tBuf, n, checkpoint_format_);
      if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
      {
        n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
        if(buftype == HISTORY)
          n = n + "_history";
        else if(buftype == DIFF)
          n = n + "_grad";
        fn->Checkpoint(tBuf, n, checkpoint_format_);
      }
    }
    else if((tenType == BNORMSCALE) || (tenType == BNORMSHIFT) || (tenType == BNORMMEAN) || (tenType == BNORMVAR))
    {
      if(nntype == "FusedBatchNorm")
      {
        FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
        bn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          bn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
      else if(nntype == "FusedConvBN")
      {
        FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
        fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
    }
  }
}

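// Reads a single checkpoint file into the given tensor buffer, either as raw
// binary or as whitespace-separated text (FP32 only). For BF16 runs, weight
// tensors (but not their history) are also down-converted into the
// low-precision shadow buffer.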
void MLEngine::read_checkpoint_file(TensorBuf* tBuf, string filename, string format)
{
  long long int bytes = tBuf->getBufferSize();
  int dtype = tBuf->getDataType();

  void* ptr;
  ptr = tBuf->getBuffer();

  FILE* f;
  if(format == "binary")
  {
    f = fopen(filename.c_str(), "rb");
    assert(f != NULL);
    size_t b = fread(ptr, 1, bytes, f);
    assert((long long int)b == bytes);
  }
  else
  {
    printf("Reading from %s\n",filename.c_str());
    f = fopen(filename.c_str(), "r");
    assert(f != NULL);
    if(dtype == DT_FLOAT)
    {
      float* p = (float*)ptr;
      for(int i=0; i < bytes/sizeof(float); i++)
        fscanf(f, "%f", &p[i]);
    }
  }
  fclose(f);

  if(data_type_ == BF16 && (filename.find("wt") != filename.npos))
    if(filename.find("history") == filename.npos)
      convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)tBuf->getLPBuffer(), bytes/sizeof(float), 0);
}

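// Restores weight, bias and batch-norm-statistics tensors from checkpoint
// files; '/' characters beyond the directory prefix are replaced with '_'
// to keep on-disk names flat.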
void MLEngine::load_checkpoint(TensorList L, int buftype, string format)
{
  TensorBuf* tBuf;

  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;
    int tenType = t->getType();
    if((tenType != CONVWEIGHT) && (tenType != CONVBIAS) && (tenType != FCWEIGHT) && (tenType != FCBIAS))
      if((tenType != BNORMSCALE) && (tenType != BNORMSHIFT) && (tenType != BNORMMEAN) && (tenType != BNORMVAR))
        continue;

    bool found = false;
    for(int index=0; index<t->getNumDataBuffers(); index++)
    {
      tBuf = t->getBuf(index);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    string n = checkpoint_dir_ + "/" + t->getTensorName();

    if(buftype == HISTORY)
      n = n + "_history";

    size_t pos;
    while((pos = n.find("/", 10)) != n.npos)
      n.replace(pos, 1, 1, '_');
    read_checkpoint_file(tBuf, n, format);
  }
}

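// Verifies the guard bands around packed tensor buffers. The layout written
// by allocate_memory() is:
//
//   [START_GUARD_BAND][tensor 0][END_GUARD_BAND][tensor 1][END_GUARD_BAND]...
//
// where every guard word is 0x7f7f7f7f (CANARY) and cp[i] holds the byte
// size of tensor i. A mismatch indicates a buffer overrun; the exit() calls
// are left commented out so that every corrupted buffer gets reported.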
void MLEngine::canary_check(void* ptr, vector<int>& cp, int nc)
{
  if(ptr == NULL)
  {
    printf("FATAL: NULL pointer to buffer\n");
    //exit(1);
    return;
  }

  int *p = (int*)ptr;
  for(int i=0; i<START_GUARD_BAND/sizeof(int); i++)
  {
    if(p[i] != 0x7f7f7f7f)
    {
      printf("Fatal: canary value overwritten at %d in buffer at %p\n",i, ptr);
      //exit(1);
    }
  }

  void *vp = (void*)(ptr + START_GUARD_BAND);

  for(int i=0; i<nc; i++)
  {
    int next = cp[i];
    vp = (void*)(vp + next);
    int *pp = (int*)vp;
    for(int j=0; j<END_GUARD_BAND/sizeof(int); j++)
    {
      if(pp[j] != 0x7f7f7f7f)
      {
        printf("Fatal: canary value overwritten at %d in buffer at %p\n",j,pp);
        //exit(1);
      }
    }
    vp += END_GUARD_BAND;
  }
}

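// Blocks until all outstanding MLSL gradient all-reduces for the given
// parameter class ("WEIGHT", "BIAS" or "COMBO") complete. The per-operation
// parameter-set counts (1, 4 and 5, respectively) presumably mirror how the
// gradient communication operations were registered elsewhere in the engine.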
void MLEngine::waitForComms(string tenType)
{
#ifdef USE_MLSL
  if(tenType=="WEIGHT")
  {
    if(!wtgrad_comms_vec.empty())
    {
      for(int i=0; i<wtgrad_comms_vec.size(); i++)
        wtgrad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
    }
  }
  else if(tenType=="BIAS")
  {
    if(!bias_grad_comms_vec.empty())
    {
      for(int i=0; i<bias_grad_comms_vec.size(); i++)
      {
        bias_grad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
        bias_grad_comms_vec[i]->GetParameterSet(1)->WaitGradientComm();
        bias_grad_comms_vec[i]->GetParameterSet(2)->WaitGradientComm();
        bias_grad_comms_vec[i]->GetParameterSet(3)->WaitGradientComm();
      }
    }
  }
  else if(tenType=="COMBO")
  {
    if(!combo_grad_comms_vec.empty())
    {
      for(int i=0; i<combo_grad_comms_vec.size(); i++)
      {
        combo_grad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
        combo_grad_comms_vec[i]->GetParameterSet(1)->WaitGradientComm();
        combo_grad_comms_vec[i]->GetParameterSet(2)->WaitGradientComm();
        combo_grad_comms_vec[i]->GetParameterSet(3)->WaitGradientComm();
        combo_grad_comms_vec[i]->GetParameterSet(4)->WaitGradientComm();
      }
    }
  }
#endif
}

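// Top-level driver, invoked as e.g. engine.run(TRAIN). In TRAIN mode it
// optionally resumes from a checkpoint, then alternates one training epoch
// (invoking every task in etg_[TRAIN]) with a validation pass, checkpointing
// weights, biases and BN statistics after each epoch. In TEST mode it loads
// the latest checkpoint and runs the test schedule only.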
void MLEngine::run(int mode)
{
  if(mode == TRAIN)
  {
    if(load_from_checkpoint_)
    {
      FILE *f = fopen("checkpoint", "r");
      if(f != NULL)
      {
        fscanf(f, "%d %f %f\n",&current_epoch_, &lr_, &scf_);
        fclose(f);
      }
      else
        printf("No checkpoint state file to read\n");

      if(current_epoch_ != num_epochs_ - 1)
        current_epoch_++;
      load_checkpoint(wTList_, DATA, checkpoint_format_);
      load_checkpoint(wTList_, HISTORY, checkpoint_format_);
      load_checkpoint(biasTList_, DATA, checkpoint_format_);
      load_checkpoint(biasTList_, HISTORY, checkpoint_format_);
      load_checkpoint(statsTList_, DATA, checkpoint_format_);

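      // Broadcast the freshly loaded master copies on NUMA node 0 to every
      // other node's replica; the first thread pinned to each node copies.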
#ifdef _OPENMP
#pragma omp parallel
#endif
      {
        int tid = omp_get_thread_num();
        int ntps = num_threads_/NUM_NUMA_NODES;
        int n = tid/ntps;
        int w = total_weights_;
        int b = total_biases_;

        if(n != 0 && tid % ntps == 0)
        {
          float *wptr = (float*)weight_buf_[n];
#if 1
          float *bptr = (float*)bias_buf_[n];
          float *sptr = (float*)stats_buf_[n];
#endif

#pragma omp simd
          for(int i=0; i<w; i++)
            wptr[i] = ((float*)weight_buf_[0])[i];

#if 1
#pragma omp simd
          for(int i=0; i<b; i++)
          {
            bptr[i] = ((float*)bias_buf_[0])[i];
            sptr[i] = ((float*)stats_buf_[0])[i];
          }
#endif
          if(lpweight_buf_[0] != NULL)
          {
            libxsmm_bfloat16 *lwptr = (libxsmm_bfloat16*)lpweight_buf_[n];
#pragma omp simd
            for(int i=0; i<w; i++)
              lwptr[i] = ((libxsmm_bfloat16*)lpweight_buf_[0])[i];
          }
        }
      }
      load_from_checkpoint_ = false;
    }

    fflush(stdout);

#ifdef USE_MLSL
    data_parallelism->Barrier(MLSL::GT_DATA);
#endif

    // current_epoch_ is set in create() function or by checkpoint code above
    for(; current_epoch_ < num_epochs_; current_epoch_++)
    {
      // Tell data node that it should use training data
      exec_mode_ = TRAIN;
      if(global_node_id_ == 0)
      {
        printf("===========================================\n");
        printf("TRAIN mode, epoch %d, training batches %d\n", current_epoch_, num_train_batches_);
        printf("===========================================\n");
      }

      // Run training network for an epoch
      struct timeval tvs, tve, tvts, tvte, tvis, tvie;
      double fbtime, runtime = 0;

      for(; current_batch_<num_train_batches_; current_batch_++)
      {
        if(global_node_id_ == 0 && current_batch_ % 100 == 0)
          printf("Executing batch number %d\n",current_batch_);

        gettimeofday(&tvs, NULL);

        for(auto it = etg_[TRAIN].begin(); it != etg_[TRAIN].end(); it++)
        {
#ifdef TIMING
          gettimeofday(&tvts, NULL);
#endif
          (*it)->invoke();

#ifdef TIMING
          gettimeofday(&tvte, NULL);
          double tasktime = (tvte.tv_sec*1e6 + tvte.tv_usec) - (tvts.tv_sec*1e6 + tvts.tv_usec);
          NNNode *nn = dynamic_cast<NNNode*>((*it)->getNode());
          if(global_node_id_ == 0)
            printf("Node %s (task %d) time = %f ms\n",nn->getNodeName().c_str(), (*it)->getBasicTaskId(), tasktime/1000);
#endif
        }

        if(solver_->getGlobalFlag())
        {
#ifdef TIMING
          gettimeofday(&tvis, NULL);
#endif

#ifdef DUMP_WT
          if(global_node_id_ == 0)
            if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
              if(current_batch_ == num_train_batches_-1)
                checkpoint(wTList_, DIFF);
#endif

#ifdef USE_MLSL
          waitForComms("WEIGHT");
          waitForComms("BIAS");
          waitForComms("COMBO");
#endif

#ifdef MLSL
          data_parallelism->Barrier(MLSL::GT_DATA);
#endif

#if 0
          solver_->applyUpdate((float**)weight_buf_, (float**)winc_buf_, wdiff_buf_, total_weights_, (float**)wt_lr_mult_, (float**)wt_decay_mult_, "WEIGHT");
#else
          solver_->applyUpdate((float**)weight_buf_, (float**)winc_buf_, wdiff_buf_, total_weights_, 1.0, 1.0, "WEIGHT");
#endif
          if(data_type_ == BF16)
            convert_f32_bf16((float**)weight_buf_, (libxsmm_bfloat16**)lpweight_buf_, total_weights_);

#if 0
          solver_->applyUpdate((float**)bias_buf_, (float**)biinc_buf_, bidiff_buf_, total_biases_, (float**)bias_lr_mult_, (float**)bias_decay_mult_, "BIAS");
#else
#if 1
          solver_->applyUpdate((float**)bias_buf_, (float**)biinc_buf_, bidiff_buf_, total_biases_, 1.0, 0.0, "BIAS");
#else
          solver_->applyUpdate((float*)bias_buf_, (float*)biinc_buf_, bidiff_buf_, total_biases_, 1.0, 0.0, "BIAS");
#endif
#endif

#ifdef TIMING
          gettimeofday(&tvie, NULL);
          double sgdtime = (tvie.tv_sec + tvie.tv_usec*1e-6) - (tvis.tv_sec + tvis.tv_usec*1e-6);
          printf("global sgd time: %f ms\n",sgdtime*1000);
#endif
        }

        gettimeofday(&tve, NULL);
        fbtime = (tve.tv_sec + tve.tv_usec*1e-6) - (tvs.tv_sec + tvs.tv_usec*1e-6);
        if(global_node_id_ == 0 && current_batch_ % 100 == 0)
          printf("Fwd-Bwd time: %f ms\n",fbtime*1000);

        if ( current_batch_ > 1 )
          runtime += fbtime;

#ifdef CANARY_CHECK
        canary_check(input_buf_, input_can_ptr, ic);
        canary_check(fact_buf_, fact_can_ptr, fac);
        canary_check(bact_buf_, bact_can_ptr, bac);
#endif
      }

      current_batch_ = 0;

      if ( num_train_batches_ > 2 ) {
        char hostname[HOST_NAME_MAX + 1];
        gethostname(hostname, HOST_NAME_MAX + 1);
        printf("%s; Average Training time = %f seconds", hostname, runtime/((double)(num_train_batches_-2)));
        if(runtime > 0) {
          printf("; Average Training throughput = %f images/s\n", ((double)(batch_size_*(num_train_batches_-2)))/runtime);
        } else {
          printf("\n");
        }
      }

      // Checkpoint weights and biases
      if(global_node_id_ == 0)
      {
        checkpoint(wTList_, DATA);
        checkpoint(wTList_, HISTORY);
        checkpoint(biasTList_, DATA);
        checkpoint(biasTList_, HISTORY);
        checkpoint(statsTList_, DATA);

#ifdef DUMP_ACT_DATA
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          checkpoint(outTList_, DATA);
          checkpoint(outTList_, DIFF);
        }
#endif

        FILE* f = fopen("checkpoint", "w");
        if(f != NULL)
        {
          fprintf(f, "%d %10g %10g\n",current_epoch_, lr_, scf_);
          fclose(f);
        }
      }
#ifdef USE_MLSL
      data_parallelism->Barrier(MLSL::GT_DATA);
#endif

      // Tell data node that it should use test data
      exec_mode_ = VAL;

      if(global_node_id_ == 0)
      {
        printf("===========================================\n");
        printf("VAL mode, testing batches %d\n", num_test_batches_);
        printf("===========================================\n");
      }

      // Run validation network at end of each epoch
      for(; current_batch_<num_test_batches_; current_batch_++)
      {
        for(int v=0; v<num_test_views_; v++)
          for(auto it = etg_[VAL].begin(); it != etg_[VAL].end(); it++)
            (*it)->invoke();
      }

      current_batch_ = 0;

#ifdef CANARY_CHECK
      canary_check(input_buf_, input_can_ptr, ic);
      canary_check(fact_buf_, fact_can_ptr, fac);
      canary_check(weight_buf_, wt_can_ptr, wtc);
      canary_check(bias_buf_, bias_can_ptr, bic);
#endif
    }

#ifdef USE_MLSL
    MLSL::Environment::GetEnv().Free(input_buf_);
    MLSL::Environment::GetEnv().Free(fact_buf_);
    MLSL::Environment::GetEnv().Free(bact_buf_);
#else
    libxsmm_free(input_buf_);
    libxsmm_free(fact_buf_);
    libxsmm_free(bact_buf_);
#endif

    for(int n=0; n<NUM_NUMA_NODES; n++)
    {
#ifdef USE_MLSL
      MLSL::Environment::GetEnv().Free(weight_buf_[n]);
      if(lpweight_buf_[n] != NULL)
        MLSL::Environment::GetEnv().Free(lpweight_buf_[n]);
      MLSL::Environment::GetEnv().Free(wdiff_buf_[n]);
      if(lpwdiff_buf_[n] != NULL)
        MLSL::Environment::GetEnv().Free(lpwdiff_buf_[n]);
      MLSL::Environment::GetEnv().Free(winc_buf_[n]);
#if 1
      MLSL::Environment::GetEnv().Free(bias_buf_[n]);
      MLSL::Environment::GetEnv().Free(bidiff_buf_[n]);
      MLSL::Environment::GetEnv().Free(biinc_buf_[n]);
      MLSL::Environment::GetEnv().Free(stats_buf_[n]);
#else
      MLSL::Environment::GetEnv().Free(bias_buf_);
      MLSL::Environment::GetEnv().Free(bidiff_buf_);
      MLSL::Environment::GetEnv().Free(biinc_buf_);
      MLSL::Environment::GetEnv().Free(stats_buf_);
#endif
#else
      libxsmm_free(weight_buf_[n]);
      libxsmm_free(wdiff_buf_[n]);
      if(lpweight_buf_[n] != NULL)
        libxsmm_free(lpweight_buf_[n]);
      if(lpwdiff_buf_[n] != NULL)
        libxsmm_free(lpwdiff_buf_[n]);
      libxsmm_free(winc_buf_[n]);
#if 1
      libxsmm_free(bias_buf_[n]);
      libxsmm_free(bidiff_buf_[n]);
      libxsmm_free(biinc_buf_[n]);
      libxsmm_free(stats_buf_[n]);
#else
      libxsmm_free(bias_buf_);
      libxsmm_free(bidiff_buf_);
      libxsmm_free(biinc_buf_);
      libxsmm_free(stats_buf_);
#endif
#endif
    }
  }
  else if(mode == TEST)
  {
    exec_mode_ = TEST;

    FILE *f = fopen("checkpoint", "r");
    if(f != NULL)
    {
      fscanf(f, "%d %f %f\n",&current_epoch_, &lr_, &scf_);
      fclose(f);
    }

    printf("====================================================================\n");
    printf("TEST mode, testing batches %d, scaling factor %.10f\n", num_test_batches_, scf_);
    printf("====================================================================\n");

    load_checkpoint(wTList_, DATA, checkpoint_format_);
    load_checkpoint(biasTList_, DATA, checkpoint_format_);
    load_checkpoint(statsTList_, DATA, checkpoint_format_);

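    // As in TRAIN mode, replicate the loaded node-0 master copies across the
    // remaining NUMA nodes before running inference.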
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = omp_get_thread_num();
      int ntps = num_threads_/NUM_NUMA_NODES;
      int n = tid/ntps;
      int w = total_weights_;
      int b = total_biases_;

      if(n != 0 && tid % ntps == 0)
      {
        float *wptr = (float*)weight_buf_[n];
#if 1
        float *bptr = (float*)bias_buf_[n];
        float *sptr = (float*)stats_buf_[n];
#endif

#pragma omp simd
        for(int i=0; i<w; i++)
          wptr[i] = ((float*)weight_buf_[0])[i];

#if 1
#pragma omp simd
        for(int i=0; i<b; i++)
        {
          bptr[i] = ((float*)bias_buf_[0])[i];
          sptr[i] = ((float*)stats_buf_[0])[i];
        }
#endif
        if(lpweight_buf_[0] != NULL)
        {
          libxsmm_bfloat16 *lwptr = (libxsmm_bfloat16*)lpweight_buf_[n];
#pragma omp simd
          for(int i=0; i<w; i++)
            lwptr[i] = ((libxsmm_bfloat16*)lpweight_buf_[0])[i];
        }
      }
    }

    // Run test network when command-line mode is set to "test"
    for(int b=0; b<num_test_batches_; b++)
    {
      for(auto it = etg_[TEST].begin(); it != etg_[TEST].end(); it++)
        (*it)->invoke();
    }
  }
}

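// Converts len FP32 values to BF16 (round-to-nearest-even via the adjust-
// then-truncate helpers) using only the threads pinned to numa_node. The
// 16-wide AVX-512 loop assumes len is a multiple of 16, which holds because
// packed buffers are padded to a multiple of num_threads_*VLEN elements.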
void MLEngine::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len, int numa_node)
{
#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = num_threads_/NUM_NUMA_NODES;
    int n = tid/ntps;
    int ltid = tid - numa_node*ntps;

    if(n == numa_node)
    {
      int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
      int tb = (ltid*jobs < len) ? ltid*jobs : len;
      int te = ((ltid+1)*jobs < len) ? (ltid+1)*jobs : len;

      for (int i = tb; i < te; i+=16 ) {
        __m512  vfp32  = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) );
        __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 );
        _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
      }
    }
  }
}

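// Per-NUMA-node variant: each thread converts its slice of the node-local
// copy in[n] into out[n], so all replicas are converted in one parallel
// region.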
void MLEngine::convert_f32_bf16(float** in, libxsmm_bfloat16** out, int len)
{
#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = num_threads_/NUM_NUMA_NODES;
    int n = tid/ntps;
    int ltid = tid - n*ntps;

    float *inp = in[n];
    libxsmm_bfloat16 *outp = out[n];

    int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
    int tb = (ltid*jobs < len) ? ltid*jobs : len;
    int te = ((ltid+1)*jobs < len) ? (ltid+1)*jobs : len;

    for (int i = tb; i < te; i+=16 ) {
      __m512  vfp32  = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(inp + i));
      __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32);
      _mm256_storeu_si256( (__m256i*)(outp+i), vbfp16 );
    }
  }
}
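
// Inverse conversion: widens BF16 values back to FP32, 16 elements per
// iteration; len is assumed to be a multiple of 16 as above.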
void MLEngine::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m256i vbfp16    = _mm256_loadu_si256( (const __m256i*)(in+i) );
    __m512  vfp32     = gxm_bfp16_to_fp32_avx512f( vbfp16 );
    _mm512_storeu_ps( out+i, vfp32 );
  }
}

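// Allocates one contiguous 2MB-aligned arena per tensor class ("INPUT",
// "FACT", "BACT", "WEIGHT", "BIAS", "STATS") and buffer type (DATA, DIFF,
// HISTORY), carves it into per-tensor sub-buffers, optionally brackets each
// sub-buffer with canary guard bands, and invokes the owning node's fill
// routines when not resuming from a checkpoint. Weight/bias/stats arenas are
// replicated across NUM_NUMA_NODES. *nc returns the canary count and
// *bufsize the total arena size.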
void MLEngine::allocate_memory(string tenType, TensorList L, int buftype, vector<int>& can_ptr, int* nc, long long int* bufsize)
{
  bool ttp = false; //(tenType != "WEIGHT") & (tenType != "BIAS");

  long long int s = ttp ? START_GUARD_BAND : 0;
  TensorBuf* tBuf;
  int num_canaries = 0;

  float* lrptr, *decptr;

  // Get total buffer size required for tensors of type buftype
  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;

    bool found = false;
    for(int i=0; i<t->getNumDataBuffers(); i++)
    {
      tBuf = t->getBuf(i);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    long long int size = tBuf->getBufferSize();
    if(size > 0)
    {
      if(global_node_id_ == 0)
      {
        printf("Tensor %s needs %lld bytes for buffer %d\n", t->getTensorName().c_str(), size, buftype);
        fflush(stdout);
      }
      s += size;
      if(ttp)
        s += END_GUARD_BAND;

      if(ttp)
        num_canaries++;
    }
  }

  if(tenType == "WEIGHT")
    total_weights_ = s/sizeof(float);
  else if(tenType == "BIAS" || tenType == "STATS")
    total_biases_ = s/sizeof(float);

  if(solver_->getGlobalFlag())
  {
    if(tenType == "WEIGHT")
    {
#ifdef BF16_MLSL
      if(buftype == DIFF)
      {
        if(data_type_ == FLOAT)
          total_weights_ = s/sizeof(float);
        else if(data_type_ == BF16)
          total_weights_ = s/sizeof(libxsmm_bfloat16);
      }
      else
#endif
        total_weights_ = s/sizeof(float);

      int factor = num_threads_ * VLEN;
      int nwt = (total_weights_ + factor - 1)/factor;
      total_weights_ = nwt * factor;

#ifdef BF16_MLSL
      if(buftype == DIFF)
      {
        if(data_type_ == FLOAT)
          s = total_weights_ * sizeof(float);
        else if(data_type_ == BF16)
          s = total_weights_ * sizeof(libxsmm_bfloat16);
      }
      else
#endif
        s = total_weights_ * sizeof(float);
    }
    else if(tenType == "BIAS" || tenType == "STATS")
    {
      total_biases_ = s / sizeof(float);
      int factor = num_threads_ * VLEN;
      int nwt = (total_biases_ + factor - 1)/factor;
      total_biases_ = nwt * factor;

      s = total_biases_ * sizeof(float);
    }
  }

  // Number of guard bands in tensor; used for canary checking
  *nc = num_canaries;

  // Allocate memory
#ifdef BF16_MLSL
  bool lp = (data_type_ == BF16) && (tenType=="WEIGHT") && (buftype == DATA);
#else
  bool lp = (data_type_ == BF16) && (tenType=="WEIGHT");
#endif

  void *buf_;
  void **ptrptr=NULL, **lptrptr=NULL;

#if 0 //def USE_MLSL
  s = ALIGN_SIZE(s, 2097152);
#endif

  if(tenType=="INPUT")
  {
#ifdef USE_MLSL
    buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
    buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
    input_buf_ = buf_;
  }
  else if(tenType == "FACT")
  {
#ifdef USE_MLSL
    buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
    buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
    fact_buf_ = buf_;
  }
  else if(tenType == "WEIGHT")
  {
    if(buftype == DATA)
    {
      for(int n=0; n<NUM_NUMA_NODES; n++)
      {
#ifdef USE_MLSL
        weight_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        weight_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
        if(lp)
#ifdef USE_MLSL
          lpweight_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s/sizeof(libxsmm_bfloat16), 2097152);
#else
          lpweight_buf_[n] = (void*)libxsmm_aligned_malloc(s/sizeof(libxsmm_bfloat16), 2097152);
#endif
      }
      buf_ = weight_buf_[0];
      ptrptr = weight_buf_;
      if(lp)
        lptrptr = lpweight_buf_;
    }
    else if(buftype == DIFF)
    {
      for(int n=0; n<NUM_NUMA_NODES; n++)
      {
#ifdef USE_MLSL
        wdiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        wdiff_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
        if(lp)
#ifdef USE_MLSL
          lpwdiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s/sizeof(libxsmm_bfloat16), 2097152);
#else
          lpwdiff_buf_[n] = (void*)libxsmm_aligned_malloc(s/sizeof(libxsmm_bfloat16), 2097152);
#endif
      }
      buf_ = wdiff_buf_[0];
      ptrptr = wdiff_buf_;
      if(lp)
        lptrptr = lpwdiff_buf_;
    }
    else if(buftype == HISTORY)
    {
      for(int n=0; n<NUM_NUMA_NODES; n++)
#ifdef USE_MLSL
        winc_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        winc_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = winc_buf_[0];
      ptrptr = winc_buf_;
    }
  }
  else if(tenType == "BACT")
  {
#ifdef USE_MLSL
    buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
    buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
    bact_buf_ = buf_;
  }
  else if(tenType == "STATS")
  {
#if 1
    for(int n=0; n<NUM_NUMA_NODES; n++)
#ifdef USE_MLSL
      stats_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
      stats_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
    buf_ = stats_buf_[0];
    ptrptr = stats_buf_;
#else
#ifdef USE_MLSL
    stats_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
    stats_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
    buf_ = stats_buf_;
#endif
  }
  else if(tenType == "BIAS")
  {
    if(buftype == DATA)
    {
#if 1
      for(int n=0; n<NUM_NUMA_NODES; n++)
#ifdef USE_MLSL
        bias_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        bias_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = bias_buf_[0];
      ptrptr = bias_buf_;
#else
#ifdef USE_MLSL
      bias_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
      bias_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = bias_buf_;
#endif
    }
    else if(buftype == DIFF)
    {
#if 1
      for(int n=0; n<NUM_NUMA_NODES; n++)
#ifdef USE_MLSL
        bidiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        bidiff_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = bidiff_buf_[0];
      ptrptr = bidiff_buf_;
#else
#ifdef USE_MLSL
      bidiff_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
      bidiff_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = bidiff_buf_;
#endif
    }
    else if(buftype == HISTORY)
    {
#if 1
      for(int n=0; n<NUM_NUMA_NODES; n++)
#ifdef USE_MLSL
        biinc_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
        biinc_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = biinc_buf_[0];
      ptrptr = biinc_buf_;
#else
#ifdef USE_MLSL
      biinc_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
#else
      biinc_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
#endif
      buf_ = biinc_buf_;
#endif
    }
  }

  // Total buffer size, including guard bands before and after each buffer (currently 64 bytes long)
  *bufsize = s + (lp ? s/sizeof(libxsmm_bfloat16) : 0);

#if 0
  printf("Tensor with buffers %d @ %p with total size %lld\n",buftype, buf_, s);
  fflush(stdout);

  if(buf_ != NULL)
  {
#ifndef USE_NUMA
    memset(buf_, 0, s);
#endif
  }
  else {
    printf("could not allocate tensor memory.. exiting\n");
    exit(-1);
  }

  if(lp && lpweight_buf_==NULL)
  {
    printf("could not allocate low precision weights memory.. exiting\n");
    exit(-1);
  }

  if(solver_->getGlobalFlag())
  {
    if(tenType == "WEIGHT" && buftype == DIFF)
    {
      for(int n=0; n<NUM_NUMA_NODES; n++)
      {
        wt_lr_mult_[n] = (float*)libxsmm_aligned_malloc(total_weights_*sizeof(float), 2097152);
        if(wt_lr_mult_[n] != NULL)
        {
          float *ptr = wt_lr_mult_[n];

#ifdef _OPENMP
#pragma omp parallel
#endif
          {
            int tid = omp_get_thread_num();
            int ntps = num_threads_/NUM_NUMA_NODES;
            int s = tid/ntps;
            if(s == n && tid % ntps == 0)
              for(int i=0; i<total_weights_; i++)
                ptr[i] = 0.0;
          }
        }

        wt_decay_mult_[n] = (float*)libxsmm_aligned_malloc(total_weights_*sizeof(float), 2097152);
        if(wt_decay_mult_[n] != NULL)
        {
          float *ptr = wt_decay_mult_[n];

#ifdef _OPENMP
#pragma omp parallel
#endif
          {
            int tid = omp_get_thread_num();
            int ntps = num_threads_/NUM_NUMA_NODES;
            int s = tid/ntps;
            if(s == n && tid % ntps == 0)
              for(int i=0; i<total_weights_; i++)
                ptr[i] = 0.0;
          }
        }
      }
      lrptr = wt_lr_mult_[0];
      decptr = wt_decay_mult_[0];
    }
    else if(tenType == "BIAS" && buftype == DIFF)
    {
      for(int n=0; n<NUM_NUMA_NODES; n++)
      {
        bias_lr_mult_[n] = (float*)_mm_malloc(total_biases_*sizeof(float), 64);
        if(bias_lr_mult_[n] != NULL)
        {
          float *ptr = bias_lr_mult_[n];

#ifdef _OPENMP
#pragma omp parallel
#endif
          {
            int tid = omp_get_thread_num();
            int ntps = num_threads_/NUM_NUMA_NODES;
            int s = tid/ntps;
            if(s == n && tid % ntps == 0)
              for(int i=0; i<total_biases_; i++)
                ptr[i] = 0.0;
          }
        }

        bias_decay_mult_[n] = (float*)_mm_malloc(total_biases_*sizeof(float), 64);
        if(bias_decay_mult_[n] != NULL)
        {
          float *ptr = bias_decay_mult_[n];

#ifdef _OPENMP
#pragma omp parallel
#endif
          {
            int tid = omp_get_thread_num();
            int ntps = num_threads_/NUM_NUMA_NODES;
            int s = tid/ntps;
            if(s == n && tid % ntps == 0)
              for(int i=0; i<total_biases_; i++)
                ptr[i] = 0.0;
          }
        }
      }
      lrptr = bias_lr_mult_[0];
      decptr = bias_decay_mult_[0];
    }
  }
#endif

  if(ttp)
    memset(buf_, CANARY, START_GUARD_BAND);

  long long int bytes=0, lpbytes=0;
  int offset=0;

  //Set up tensor buffer pointers
  void* ptr = ttp ? buf_ + START_GUARD_BAND : buf_;
  void* lptr = lp ? lpweight_buf_[0] : NULL;
  void* lgptr = lp ? lpwdiff_buf_[0] : NULL;

  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;

    bool found = false;
    for(int i=0; i<t->getNumDataBuffers(); i++)
    {
      tBuf = t->getBuf(i);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    // Don't process Split nodes further for forward activations
    string nntype = dynamic_cast<NNNode*>(t->getOwner())->getNodeType();
    if(nntype.find("Split") != nntype.npos && buftype == DATA)
      continue;

    // Scrub or initialize buffers appropriately
    bytes = tBuf->getBufferSize();
    assert(ptr+bytes <= buf_+s);

    lpbytes = lp ? bytes/sizeof(libxsmm_bfloat16) : 0;

#ifndef USE_NUMA
    if(t->getType() == INPUT || t->getType() == ACT)
    {
      if(bytes > 0)
        memset(ptr, 0, bytes);
    }
#endif

    // Set each node's tensor buffer pointers to the appropriate location in the global buffer
    if(tenType == "WEIGHT" || tenType == "BIAS" || tenType == "STATS")
    {
      if(buftype == DATA || buftype == DIFF)
      {
        tBuf->setBufferPtr(ptrptr);
        tBuf->setOffset(offset);
      }
      tBuf->setBuffer(ptr);

      if(lp)
      {
        if(buftype == DATA)
          tBuf->setLPBuffer(lptr);
        else if(buftype == DIFF)
          tBuf->setLPBuffer(lgptr);
        tBuf->setLPBufferPtr(lptrptr);
      }
    }
    else
      tBuf->setBuffer(ptr);

    // If weight or bias tensor, call corresponding initialization function (for training only)
    if(!is_inference_only())
    {
      int tType = t->getType();
      if(tType == CONVWEIGHT)
      {
        if(nntype == "FusedConvBN")
        {
          FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
          {
            fcbn->fillWeightBuffers(tBuf, buftype, bytes);
#if 0
            if(lp)
              convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
#endif
          }
#if 0
          if(solver_->getGlobalFlag())
            if(buftype == DIFF)
              if(data_type_ == FLOAT)
                fcbn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
              else if(data_type_ == BF16)
                fcbn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
#endif
        }
        else if(nntype == "Convolution")
        {
          ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
          {
            cn->fillWeightBuffers(tBuf, buftype, bytes);
#if 0
            if(lp)
              convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
#endif
          }

#if 0
          if(solver_->getGlobalFlag())
            if(buftype == DIFF)
              if(data_type_ == FLOAT)
                cn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
              else if(data_type_ == BF16)
                cn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
#endif
        }
      }
      else if(tType == CONVBIAS)
      {
        ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
        assert(bytes > 0);
        if(!load_from_checkpoint_)
          cn->fillBiasBuffers(tBuf, buftype, bytes);
#if 0
        if(solver_->getGlobalFlag())
          if(buftype == DIFF)
            cn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
#endif
      }
      else if(tType == FCWEIGHT)
      {
        FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
        assert(bytes > 0);
        if(!load_from_checkpoint_)
        {
          fn->fillWeightBuffers(tBuf, buftype, bytes);
#if 0
          if(lp)
            convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
#endif
        }

#if 0
        if(solver_->getGlobalFlag())
          if(buftype == DIFF)
            if(data_type_ == FLOAT)
              fn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
            else if(data_type_ == BF16)
              fn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
#endif
      }
      else if(tType == FCBIAS)
      {
        FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
        assert(bytes > 0);
        if(!load_from_checkpoint_)
          fn->fillBiasBuffers(tBuf, buftype, bytes);
#if 0
        if(solver_->getGlobalFlag())
          if(buftype == DIFF)
            fn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
#endif
      }
      else if((tType == BNORMSCALE) || (tType == BNORMSHIFT))
      {
        if(nntype == "FusedConvBN")
        {
          FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
            fcbn->fillBuffer(tBuf, buftype, bytes);
#if 0
          if(solver_->getGlobalFlag())
            if(buftype == DIFF)
              fcbn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
#endif
        }
        else if(nntype == "FusedBatchNorm")
        {
          FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
            bn->fillBuffer(tBuf, buftype, bytes);
#if 0
          if(solver_->getGlobalFlag())
            if(buftype == DIFF)
              bn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
#endif
        }
      }
      else if((tType == BNORMMEAN) || (tType == BNORMVAR))
      {
        if(nntype == "FusedConvBN")
        {
          FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
            fcbn->fillBuffer(tBuf, buftype, bytes);
        }
        else if(nntype == "FusedBatchNorm")
        {
          FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
          assert(bytes > 0);
          if(!load_from_checkpoint_)
            bn->fillBuffer(tBuf, buftype, bytes);
        }
      }
    }

    if(bytes > 0)
    {
      ptr += bytes;
      if(lp)
      {
        if(buftype == DATA)
          lptr += lpbytes;
#ifndef BF16_MLSL
        else if(buftype == DIFF)
          lgptr += lpbytes;
#endif
      }

#ifdef BF16_MLSL
      if(tenType == "WEIGHT" && buftype == DATA)
        offset += bytes/sizeof(float);
      else if(tenType == "WEIGHT" && buftype == DIFF)
      {
        if(data_type_ == FLOAT)
          offset += bytes/sizeof(float);
        else if(data_type_ == BF16)
          offset += bytes/sizeof(libxsmm_bfloat16);
      }
#else
      if(tenType == "WEIGHT")
        offset += bytes/sizeof(float);
#endif
      else if((tenType == "BIAS" && (buftype == DATA || buftype == DIFF)) || tenType == "STATS")
        offset += bytes/sizeof(float);

#if 0
      if(solver_->getGlobalFlag())
      {
        if(tenType == "WEIGHT" && buftype == DIFF)
        {
          if(data_type_ == FLOAT)
          {
            lrptr += bytes/sizeof(float);
            decptr += bytes/sizeof(float);
          }
          else if(data_type_ == BF16)
          {
            lrptr += bytes/sizeof(libxsmm_bfloat16);
            decptr += bytes/sizeof(libxsmm_bfloat16);
          }
        }
        else if(tenType == "BIAS" && buftype == DIFF)
        {
          lrptr += bytes/sizeof(float);
          decptr += bytes/sizeof(float);
        }
      }
#endif

      assert(ptr <= buf_ + s);

      // For canary checking
      if(ttp)
      {
        memset(ptr, CANARY, END_GUARD_BAND);
        can_ptr.push_back(bytes);
        assert(can_ptr.size() <= num_canaries);
      }
      if(ttp)
        ptr += END_GUARD_BAND;
    }
    assert(ptr <= buf_ + s);
#if 0
    printf("ptr @ %p\n",ptr);
#endif
  }

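  // Replicate the initialized master arena on NUMA node 0 into each node's
  // replica; one thread per node performs the copy, and for BF16 runs the
  // low-precision weight shadow is regenerated afterwards.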
1608   if(tenType=="WEIGHT" && buftype==DATA)
1609   {
1610 #ifdef _OPENMP
1611 #pragma omp parallel
1612 #endif
1613     {
1614       int tid = omp_get_thread_num();
1615       int ntps = num_threads_/NUM_NUMA_NODES;
1616       int n = tid/ntps;
1617       int w = total_weights_;
1618       if(n != 0 && tid % ntps == 0)
1619       {
1620         float *wtptr = (float*)weight_buf_[n];
1621 
1622 #pragma omp simd
1623         for(int i=0; i<w; i++)
1624           wtptr[i] = ((float*)weight_buf_[0])[i];
1625       }
1626     }
1627 
1628     if(lp)
1629       convert_f32_bf16((float**)weight_buf_, (libxsmm_bfloat16**)lpweight_buf_, total_weights_);
1630   }
1631 
1632   if(tenType=="WEIGHT" && buftype==DIFF)
1633   {
1634     if(data_type_ == FLOAT)
1635     {
1636 #ifdef _OPENMP
1637 #pragma omp parallel
1638 #endif
1639       {
1640         int tid = omp_get_thread_num();
1641         int ntps = num_threads_/NUM_NUMA_NODES;
1642         int n = tid/ntps;
1643         int w = total_weights_;
1644         if(n != 0 && tid % ntps == 0)
1645         {
1646           float *wdiff = (float*)wdiff_buf_[n];
1647 
1648 #pragma omp simd
1649           for(int i=0; i<w; i++)
1650             wdiff[i] = ((float*)wdiff_buf_[0])[i];
1651         }
1652       }
1653     }
1654     else if(data_type_ == BF16)
1655     {
1656 #ifdef _OPENMP
1657 #pragma omp parallel
1658 #endif
1659       {
1660         int tid = omp_get_thread_num();
1661         int ntps = num_threads_/NUM_NUMA_NODES;
1662         int n = tid/ntps;
1663         int w = total_weights_;
1664         if(n != 0 && tid % ntps == 0)
1665         {
1666           libxsmm_bfloat16 *lpwdiff = (libxsmm_bfloat16*)lpwdiff_buf_[n];
1667           float *wdiff = (float*)wdiff_buf_[n];
1668 
1669 #pragma omp simd
1670           for(int i=0; i<w; i++)
1671           {
1672             lpwdiff[i] = ((libxsmm_bfloat16*)lpwdiff_buf_[0])[i];
1673             wdiff[i] = ((float*)wdiff_buf_[0])[i];
1674           }
1675         }
1676       }
1677     }
1678 
1679 #if 0
1680 #ifdef _OPENMP
1681 #pragma omp parallel
1682 #endif
1683     {
1684       int tid = omp_get_thread_num();
1685       int ntps = num_threads_/NUM_NUMA_NODES;
1686       int n = tid/ntps;
1687       int w = total_weights_;
1688       if(n != 0 && tid % ntps == 0)
1689       {
1690         float *lrp = (float*)wt_lr_mult_[n];
1691         float *dcp = (float*)wt_decay_mult_[n];
1692 
1693 #pragma omp simd
1694         for(int i=0; i<w; i++)
1695         {
1696           lrp[i] = ((float*)wt_lr_mult_[0])[i];
1697           dcp[i] = ((float*)wt_decay_mult_[0])[i];
1698         }
1699       }
1700     }
1701 #endif
1702   }
1703 
1704   if(tenType=="WEIGHT" && buftype==HISTORY)
1705   {
1706 #ifdef _OPENMP
1707 #pragma omp parallel
1708 #endif
1709     {
1710       int tid = omp_get_thread_num();
1711       int ntps = num_threads_/NUM_NUMA_NODES;
1712       int n = tid/ntps;
1713       int w = total_weights_;
1714       if(n != 0 && tid % ntps == 0)
1715       {
1716         float *inc = (float*)winc_buf_[n];
1717 
1718 #pragma omp simd
1719         for(int i=0; i<w; i++)
1720           inc[i] = ((float*)winc_buf_[0])[i];
1721       }
1722     }
1723   }
1724 
#if 1
  if(tenType == "BIAS" && buftype == DATA)
  {
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = omp_get_thread_num();
      int ntps = num_threads_/NUM_NUMA_NODES;
      int n = tid/ntps;
      int b = total_biases_;

      if(n != 0 && tid % ntps == 0)
      {
        float *bias = (float*)bias_buf_[n];

#pragma omp simd
        for(int i=0; i<b; i++)
          bias[i] = ((float*)bias_buf_[0])[i];
      }
    }
  }

  if(tenType == "BIAS" && buftype == DIFF)
  {
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = omp_get_thread_num();
      int ntps = num_threads_/NUM_NUMA_NODES;
      int n = tid/ntps;
      int b = total_biases_;

      if(n != 0 && tid % ntps == 0)
      {
        float *bidiff = (float*)bidiff_buf_[n];
#if 0
        float *lrp = (float*)bias_lr_mult_[n];
        float *dcp = (float*)bias_decay_mult_[n];
#endif

#pragma omp simd
        for(int i=0; i<b; i++)
        {
          bidiff[i] = ((float*)bidiff_buf_[0])[i];
#if 0
          lrp[i] = ((float*)bias_lr_mult_[0])[i];
          dcp[i] = ((float*)bias_decay_mult_[0])[i];
#endif
        }
      }
    }
  }

  if(tenType == "BIAS" && buftype == HISTORY)
  {
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = omp_get_thread_num();
      int ntps = num_threads_/NUM_NUMA_NODES;
      int n = tid/ntps;
      int b = total_biases_;

      if(n != 0 && tid % ntps == 0)
      {
        float *biinc = (float*)biinc_buf_[n];

#pragma omp simd
        for(int i=0; i<b; i++)
          biinc[i] = ((float*)biinc_buf_[0])[i];
      }
    }
  }

  if(tenType == "STATS")
  {
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = omp_get_thread_num();
      int ntps = num_threads_/NUM_NUMA_NODES;
      int n = tid/ntps;
      int b = total_biases_;

      if(n != 0 && tid % ntps == 0)
      {
        float *stats = (float*)stats_buf_[n];

#pragma omp simd
        for(int i=0; i<b; i++)
          stats[i] = ((float*)stats_buf_[0])[i];
      }
    }
  }
#endif
}

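// insertSplitNodes builds a copy of the topology (ps) from the parsed one
// (p), inserting an explicit Split node wherever a single top (output
// tensor) is consumed by more than one downstream node, and rewiring each
// consumer to a distinct split output. A sketch of the renaming scheme,
// using hypothetical node names (derived from the string construction
// below, not from an actual topology):
//
//   before:  node "bn1" produces top "t"; "conv2" and "res1" consume "t"
//   after:   Split node "t_bn1_0_split" consumes "t" and produces
//            "t_conv2_0_split_0" and "t_res1_0_split_1"; the bottoms of
//            "conv2" and "res1" are rewritten to those names.
//
// Split nodes fed by the "label" tensor do not propagate gradients.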
void MLEngine::insertSplitNodes(NTGParameter& p, NTGParameter* ps)
{
  ps->CopyFrom(p);
  ps->clear_node();

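  // Pass 1: record every top together with the node that produces it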
  vector< pair<string, string> > top_names;

  for(int i=0; i<p.node_size(); i++)
  {
    const NodeParameter& np = p.node(i);
    string nn = np.name();
    for(int j=0; j<np.top_size(); j++)
      top_names.push_back(make_pair(np.top(j), nn));
  }

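  // Pass 2: for every recorded top, collect all *other* nodes that list it
  // as a bottom. A top with more than one entry here has multiple consumers
  // and therefore needs a Split node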
  std::multimap<std::string, NodeParameter> top_as_bot;

  for(int i=0; i < top_names.size(); i++)
  {
    pair<string, string> tn = top_names[i];
    for(int j=0; j < p.node_size(); j++)
    {
      const NodeParameter& np = p.node(j);
      string nn = p.node(j).name();
      if(nn.compare(tn.second) == 0) continue;
      for(int k=0; k < np.bottom_size(); k++)
      {
        std::string t = tn.first;
        if(t.compare(p.node(j).bottom(k)) == 0)
          top_as_bot.insert(make_pair(t, p.node(j)));
      }
    }
  }

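  // Pass 3: copy the nodes into ps, appending a Split node after any node
  // whose top has multiple consumers. old_bottom maps a shared top to its
  // consumer names; new_bottom maps a consumer name to the split output it
  // should read from instead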
  std::multimap<std::string, std::string> old_bottom;
  std::multimap<std::string, std::string> new_bottom;

  for(int i=0; i<p.node_size(); i++)
  {
    NodeParameter* np = ps->add_node();
    np->CopyFrom(p.node(i));
    string onn = np->name();

    for(int j=0; j<np->top_size(); j++)
    {
      string t = np->top(j);
      int split_count = top_as_bot.count(t);
      if(split_count > 1)
      {
        NodeParameter *snp = ps->add_node();
        snp->Clear();
        snp->add_bottom(t);
        string snn = t + "_" + onn + "_" + std::to_string(j) + "_split";
        snp->set_name(snn);
        snp->set_type("Split");
        if(t.compare("label") == 0)
          snp->set_propagate_down(false);

        std::multimap<string, NodeParameter>::iterator it;
        int k = 0;
        for(it=top_as_bot.equal_range(t).first; it != top_as_bot.equal_range(t).second; it++)
        {
          NodeParameter onp = (*it).second;
          string nn = onp.name();

          string stn = t + "_" + nn + "_" + std::to_string(j) + "_split_" + std::to_string(k);
          snp->add_top(stn);
          k++;

          for(int l=0; l<onp.bottom_size(); l++)
          {
            if(onp.bottom(l) == t)
            {
              old_bottom.insert(make_pair(t, nn));
              new_bottom.insert(make_pair(nn, stn));
            }
          }
        }
      }
    }
  }

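  // Pass 4: rewrite each consumer's bottoms. For every non-Split node whose
  // bottom appears in old_bottom, find the entry matching this node's name
  // and substitute the corresponding split output recorded in new_bottom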
  std::multimap<std::string, std::string>::iterator it1;
  std::multimap<std::string, std::string>::iterator it2;
  for(int i=0; i<ps->node_size(); i++)
  {
    NodeParameter* mn = ps->mutable_node(i);
    if(mn->type().compare("Split") == 0) continue;
    for(int j=0; j<mn->bottom_size(); j++)
    {
      string t = mn->bottom(j);
      it1 = old_bottom.find(t);
      if(it1 == old_bottom.end()) continue;

      for(it1=old_bottom.equal_range(t).first; it1 != old_bottom.equal_range(t).second; it1++)
        if(mn->name() == (*it1).second) break;

      assert(it1 != old_bottom.end());
      string s = (*it1).second;
      for(it2=new_bottom.equal_range(s).first; it2 != new_bottom.equal_range(s).second; it2++)
      {
        string v = (*it2).second;
        if(v.find(mn->bottom(j)) != v.npos)
          mn->set_bottom(j, v);
      }
    }
  }
}

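// create() drives engine setup for a given mode: it parses the topology and
// (optionally) solver configs, inserts Split nodes, instantiates every node
// via the nodeTypes registry, builds the NN graph, assigns tasks to
// execution bins, and finally allocates all tensor buffers.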
void MLEngine::create(int mode, string ntgConfig, string solverConfig)
{
  bool parsed = parseMLConfig(ntgConfig, &ntgparam_);
  if(!parsed) exit(-1);

  if(!solverConfig.empty())
  {
    parsed = parseSolverConfig(solverConfig, &sparam_);
    if(!parsed) exit(-1);

    num_epochs_ = sparam_.max_epochs();
    current_epoch_ = 0;
    current_batch_ = 0;
    load_from_checkpoint_ = sparam_.load_checkpoint();
    checkpoint_dir_ = sparam_.checkpoint_dir();
    checkpoint_format_ = sparam_.checkpoint_format();
    data_type_ = sparam_.data_type();
  }

#ifdef _OPENMP
  num_threads_ = omp_get_max_threads();
#else
  num_threads_ = 1;
#endif

  printf("Using %d threads\n",num_threads_);

#ifdef USE_MLSL
  global_node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
  num_machines_ = MLSL::Environment::GetEnv().GetProcessCount();
  data_parallelism = NULL;
  if(mode == TRAIN || mode == VAL)
    session_ = MLSL::Environment::GetEnv().CreateSession(MLSL::PT_TRAIN);
  else
    session_ = MLSL::Environment::GetEnv().CreateSession(MLSL::PT_TEST);
#else
  global_node_id_ = 0;
  num_machines_ = 1;
#endif

  // In TEST mode the engine runs inference only
  inferenceOnly_ = (mode == TEST);

  // Initialize solver node
  int ni = find_in_nodeTypeList("Solver"); // type-list lookup; index currently unused here
  solverParams_ = parseSolverParams(&sparam_);
  solver_ = new SolverNode(solverParams_, this);

  /*************************************************************************************/
  /*** Create a global tensor to hold scratch memory needed by Conv layers (LIBXSMM) ***/
  /*************************************************************************************/
  tenScratch_ = new Tensor("scratch");
  tenScratchBuf_ = tenScratch_->getBuf(DATA);
  tenScratchBuf_->setBufferPtr(scratch);

  NTGParameter split_ntgparam;

  insertSplitNodes(ntgparam_, &split_ntgparam);
  if(global_node_id_ == 0)
    split_ntgparam.PrintDebugString();

  int numNodes = split_ntgparam.node_size();

  for(int i=0; i<numNodes; i++)
  {
    // Get the name and type of each node, then call the parse and create
    // functions registered for that type in the nodeTypes list
    NodeParameter np = split_ntgparam.node(i);
    string ntype = np.type();

#ifdef DEBUG
    printf("node type %s\n",ntype.c_str());
#endif
    int j = find_in_nodeTypeList(ntype);

    MLParams *p = nodeTypes[j].parse(&np);
    MLNode *node = nodeTypes[j].create(p, this);
    ntg_.push_back(node);
#ifdef USE_MLSL
    if(ntype.find("Data") != ntype.npos)
      data_parallelism = MLSL::Environment::GetEnv().CreateDistribution(num_machines_, 1);
#endif
  }

  // We require that the first node in the topology is a data node;
  // graph creation starts from the data node
  NNNode* dnode = dynamic_cast<NNNode*>(ntg_[0]);
  assert(dnode != NULL);

  string first = dnode->getNodeType();
#ifdef DEBUG
  printf("first node type %s\n",first.c_str());
#endif
  assert(first.find("Data") != first.npos);

  // Create the neural network graph for training or testing mode
  dnode->createNNGraph(mode);

  // Forward-pass binning: look for tasks attached to nodes with no
  // successors and add them to the Execution Task Graph (etg) first
  for(int i=numNodes-1; i>0; i--)
  {
    NNNode *nn = dynamic_cast<NNNode*>(ntg_[i]);
    Task* t = nn->getBasicTask(BASIC_TASK_FORW);

    if(nn->getNumNextNodes() == 0)
    {
      etg_[mode].push_back(t);
#ifdef DEBUG
      printf("FP task %p (node %s), bin %d pushed to etg_\n",t, nn->getNodeName().c_str(), t->getMaxBin());
#endif
    }
  }

  // Assign bins to tasks based on their dependencies. Tasks with a lower bin
  // number must execute before those with a higher bin number; tasks with the
  // same bin number can execute in parallel. Ensure there are no duplicate
  // tasks in etg
  create_schedule(mode);
  optimize_schedule(mode);

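  // The validation schedule reuses the training schedule's forward tasks:
  // copying stops at the first non-forward task (forward tasks are assumed
  // to form a contiguous prefix of etg_[TRAIN])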
  if(mode == TRAIN)
  {
    for(auto it = etg_[mode].begin(); it != etg_[mode].end(); it++)
    {
      Task *t = *it;
      if(t->getBasicTaskId() == BASIC_TASK_FORW)
        etg_[VAL].push_back(t);
      else
        break;
    }
  }

#ifdef DEBUG
  for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
  {
    Task* t = (*it);
    string s = dynamic_cast<NNNode*>(t->getNode())->getNodeName();
    if(t->getBasicTaskId() == BASIC_TASK_FORW)
      printf("FP Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
    else if(t->getBasicTaskId() == BASIC_TASK_BACK)
      printf("BP Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
    else if(t->getBasicTaskId() == BASIC_TASK_WGRAD)
      printf("WG Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
    else
      printf("SOLVER Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
  }
#endif

  if(mode == TRAIN)
    printf("Training schedule has %u tasks\n",(unsigned int)etg_[mode].size());
  else
    printf("Testing schedule has %u tasks\n",(unsigned int)etg_[mode].size());

  /*********************************************************************/
  /*** Allocate memory and set pointers for INPUT and LABEL buffers ***/
  /*********************************************************************/
  long long int total_input_size;
  long long int max_fwd_buffer_size=0;
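  // max_fwd_buffer_size is only consumed by the USE_OPTBP_ALLOC path below;
  // it is presumably updated to the largest forward-activation buffer size
  // during allocation (it is not modified in this function as written)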

  allocate_memory("INPUT", inTList_, DATA, input_can_ptr, &ic, &total_input_size);
  if(global_node_id_ == 0)
    printf("Total input memory allocated %lld bytes\n", total_input_size);

  /**********************************************************************/
  /*** Allocate memory and set pointers for FORWARD ACTIVATION buffer ***/
  /**********************************************************************/
  long long int total_fact_size;
  allocate_memory("FACT", outTList_, DATA, fact_can_ptr, &fac, &total_fact_size);
  if(global_node_id_ == 0)
    printf("Total forward activation memory allocated %lld bytes\n", total_fact_size);

  /***********************************************************/
  /*** Allocate memory and set pointers for WEIGHTS buffer ***/
  /***********************************************************/
  long long int total_weight_size;
  allocate_memory("WEIGHT", wTList_, DATA, wt_can_ptr, &wtc, &total_weight_size);
  if(global_node_id_ == 0)
    printf("Total weights memory allocated %lld bytes\n", total_weight_size);

  /**********************************************************/
  /*** Allocate memory and set pointers for BIASES buffer ***/
  /**********************************************************/
  long long int total_bias_size;
  allocate_memory("BIAS", biasTList_, DATA, bias_can_ptr, &bic, &total_bias_size);
  if(global_node_id_ == 0)
    printf("Total bias memory allocated %lld bytes\n", total_bias_size);

  /*********************************************************/
  /*** Allocate memory and set pointers for STATS buffer ***/
  /*********************************************************/
  long long int total_stats_size;
  allocate_memory("STATS", statsTList_, DATA, stats_can_ptr, &sic, &total_stats_size);
  if(global_node_id_ == 0)
    printf("Total stats memory allocated %lld bytes\n", total_stats_size);

  // Required only for training; initialized to zero so the final total
  // below is correct in inference-only mode
  long long int total_bp_size = 0;
  if(!inferenceOnly_)
  {
    /***********************************************************************/
    /*** Allocate memory and set pointers for BACKWARD ACTIVATION buffer ***/
    /***********************************************************************/
#if !defined(USE_OPTBP_ALLOC)
    long long int total_bact_size;
    allocate_memory("BACT", outTList_, DIFF, bact_can_ptr, &bac, &total_bact_size);
    if(global_node_id_ == 0)
      printf("Total backward activation memory allocated %lld bytes\n", total_bact_size);
#else
    long long int total_bact_size = NDIFFS * max_fwd_buffer_size;
    allocate_gradient_tensor(outTList_, DIFF, NDIFFS, max_fwd_buffer_size);
    if(global_node_id_ == 0)
      printf("Total backward activation memory allocated %lld bytes\n", total_bact_size);
#endif
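
    // Under USE_OPTBP_ALLOC, gradient tensors appear to share a small pool
    // of NDIFFS buffers, each sized to the largest forward activation,
    // instead of getting one buffer per tensor; this bounds backward
    // activation memory at NDIFFS * max_fwd_buffer_size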

    /********************************************************************/
    /*** Allocate memory and set pointers for WEIGHT GRADIENTS buffer ***/
    /********************************************************************/
    long long int total_wdiff_size;
    allocate_memory("WEIGHT", wTList_, DIFF, wdiff_can_ptr, &wdc, &total_wdiff_size);
    if(global_node_id_ == 0)
      printf("Total weight gradient memory allocated %lld bytes\n", total_wdiff_size);

    /*********************************************************************/
    /*** Allocate memory and set pointers for WEIGHT INCREMENTS buffer ***/
    /*********************************************************************/
    long long int total_winc_size;
    allocate_memory("WEIGHT", wTList_, HISTORY, winc_can_ptr, &wic, &total_winc_size);
    if(global_node_id_ == 0)
      printf("Total weight increment memory allocated %lld bytes\n", total_winc_size);

    /******************************************************************/
    /*** Allocate memory and set pointers for BIAS GRADIENTS buffer ***/
    /******************************************************************/
    long long int total_bidiff_size;
    allocate_memory("BIAS", biasTList_, DIFF, bidiff_can_ptr, &bidc, &total_bidiff_size);
    if(global_node_id_ == 0)
      printf("Total bias gradient memory allocated %lld bytes\n", total_bidiff_size);

    /*******************************************************************/
    /*** Allocate memory and set pointers for BIAS INCREMENTS buffer ***/
    /*******************************************************************/
    long long int total_biinc_size;
    allocate_memory("BIAS", biasTList_, HISTORY, biinc_can_ptr, &biic, &total_biinc_size);
    if(global_node_id_ == 0)
      printf("Total bias increment memory allocated %lld bytes\n", total_biinc_size);

    total_bp_size = total_bact_size + total_wdiff_size + total_winc_size + total_bidiff_size + total_biinc_size;
  }

  // Stats memory is included in the grand total alongside the other buffers
  long long int total_memory = total_input_size + total_fact_size + total_weight_size + total_bias_size + total_stats_size + total_bp_size;
  if(global_node_id_ == 0)
    printf("Total tensor memory allocated %lld bytes\n",total_memory);
}