1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11
12
13 #include <map>
14 #include "assert.h"
15 #include "proto/gxm.pb.h"
16 #include "Node.hpp"
17 #include "Engine.hpp"
18 #include "Conv.hpp"
19 #include "FullyConnected.hpp"
20 #include "FusedBNorm.hpp"
21 #include "FusedConvBN.hpp"
22 #include "DummyData.hpp"
23 #include "TypeList.hpp"
24
25 #include "unistd.h"
26 #include "limits.h"
27
28 #define VLEN 16
29
30 using namespace std;
31 using namespace gxm;
32
33 int iter=0;
34
compare_task_bins(Task * first,Task * second)35 bool compare_task_bins(Task* first, Task* second)
36 {
37 return (first->getMaxBin() < second->getMinBin());
38 }
39
// Build the execution schedule (etg_[mode]) by walking task dependencies and
// assigning bin numbers.  etg_ is iterated while tasks are appended to it, so
// newly scheduled tasks are themselves visited on later iterations (list
// iterators stay valid across push_back).
void MLEngine::create_schedule(int mode)
{
  // Pass 1: pull forward-propagation producers into the schedule.  Each FP
  // dependency is placed one bin earlier than the task that consumes it.
  for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
  {
    Task* t = *it;
    vector<Task*> tp = t->getBackDepTasks();
    for(int i=0; i<tp.size(); i++) {
      // Node name is only used by the DEBUG printout below.
      string s = dynamic_cast<NNNode*>(tp[i]->getNode())->getNodeName();

      if(tp[i]->getBasicTaskId() == BASIC_TASK_FORW) {
        int maxbin = tp[i]->getMaxBin();
        // Schedule the producer if it has no bin yet (0) or its current bin
        // would not run strictly before this task.
        if((maxbin == 0) || (maxbin > t->getMinBin()-1))
        {
          tp[i]->setMinBin(t->getMaxBin() - 1);
          tp[i]->setMaxBin(t->getMaxBin() - 1);
          etg_[mode].push_back(tp[i]);
#ifdef DEBUG
          printf("FP task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
#endif
        }
      }
    }
  }

  // Pass 2 (training only): append backward / weight-gradient / solver tasks
  // one bin after the task they depend on.
  if(mode == TRAIN)
  {
    for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
    {
      Task* t = *it;
      vector<Task*> tp = t->getForwDepTasks();
      for(int i=0; i<tp.size(); i++)
      {
        // Node name is only used by the DEBUG printouts below.
        string s = dynamic_cast<NNNode*>(tp[i]->getNode())->getNodeName();

        if(tp[i]->getBasicTaskId() != BASIC_TASK_FORW)
        {
          int maxbin = tp[i]->getMaxBin();
          // Schedule the dependent if it has no bin yet (0) or its current
          // bin would not run strictly after this task.
          if((maxbin == 0) || (maxbin < t->getMinBin()+1))
          {
            tp[i]->setMinBin(t->getMaxBin() + 1);
            tp[i]->setMaxBin(t->getMaxBin() + 1);
            etg_[mode].push_back(tp[i]);
#ifdef DEBUG
            if(tp[i]->getBasicTaskId() == BASIC_TASK_BACK)
              printf("BP task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
            else if(tp[i]->getBasicTaskId() == BASIC_TASK_WGRAD)
              printf("WU task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
            else if(tp[i]->getBasicTaskId() == BASIC_TASK_SOLVE)
              printf("SOLVE task %p (node %s), with bin %d pushed to etg_\n",tp[i], s.c_str(), tp[i]->getMaxBin());
#endif
          }
        }
      }
    }
  }
}
96
find_in_nodeTypeList(string name)97 int MLEngine::find_in_nodeTypeList(string name)
98 {
99 for(int i=0; i<numTypes; i++)
100 if(nodeTypes[i].typeName.compare(name) == 0)
101 return i;
102 return -1;
103 }
104
register_tensor(string name,int type,Tensor * t)105 bool MLEngine::register_tensor(string name, int type, Tensor* t)
106 {
107 TensorPair tp;
108 tp.name = name;
109 tp.t = t;
110
111 Iter it;
112
113 switch(type)
114 {
115 case INPUT:
116 case LABEL:
117 it = inTList_.insert(inTList_.end(), tp);
118 inTensorMap_[name] = it;
119 break;
120
121 case ACT:
122 it = outTList_.insert(outTList_.end(), tp);
123 outTensorMap_[name] = it;
124 break;
125
126 case CONVWEIGHT:
127 case FCWEIGHT:
128 it = wTList_.insert(wTList_.end(), tp);
129 weightTensorMap_[name] = it;
130 break;
131
132 case CONVBIAS:
133 case FCBIAS:
134 case BNORMSCALE:
135 case BNORMSHIFT:
136 it = biasTList_.insert(biasTList_.end(), tp);
137 biasTensorMap_[name] = it;
138 break;
139
140 case BNORMMEAN:
141 case BNORMVAR:
142 it = statsTList_.insert(statsTList_.end(), tp);
143 statsTensorMap_[name] = it;
144 break;
145 }
146 return true;
147 }
148
get_tensor(string name,int type)149 Tensor* MLEngine::get_tensor(string name, int type)
150 {
151 Iter it = defTList_.end();
152
153 switch(type)
154 {
155 case INPUT:
156 case LABEL:
157 it = inTensorMap_[name];
158 break;
159
160 case ACT:
161 it = outTensorMap_[name];
162 break;
163
164 case CONVWEIGHT:
165 case FCWEIGHT:
166 it = weightTensorMap_[name];
167 break;
168
169 case CONVBIAS:
170 case FCBIAS:
171 case BNORMSCALE:
172 case BNORMSHIFT:
173 it = biasTensorMap_[name];
174 break;
175
176 case BNORMMEAN:
177 case BNORMVAR:
178 it = statsTensorMap_[name];
179 break;
180 }
181
182 if(it == defTList_.end())
183 return NULL;
184
185 TensorPair tp = *it;
186 return tp.t;
187 }
188
// Sort the schedule for the given mode by bin number and remove duplicate
// task entries created during create_schedule().
// NOTE(review): stable_partition with dupChecker_() keeps first occurrences
// at the front and erase() drops the rest; the trailing unique() then removes
// any remaining *adjacent* duplicates -- confirm both passes are required.
void MLEngine::optimize_schedule(int mode)
{
  etg_[mode].sort(compare_task_bins);
  etg_[mode].erase(std::stable_partition(etg_[mode].begin(), etg_[mode].end(), dupChecker_()), etg_[mode].end());
  etg_[mode].unique();
}
195
clear_history(TensorList L)196 void MLEngine::clear_history(TensorList L)
197 {
198 int buftype = HISTORY;
199
200 for(Iter it=L.begin(); it != L.end(); it++)
201 {
202 Tensor* t = it->t;
203 TensorBuf *tBuf;
204 bool found = false;
205 for(int index=0; index<t->getNumDataBuffers(); index++)
206 {
207 tBuf = t->getBuf(index);
208 if(tBuf->getBufferType() == buftype)
209 {
210 found = true;
211 break;
212 }
213 }
214 if(!found) continue;
215
216 long long int bytes = tBuf->getBufferSize();
217 int dtype = tBuf->getDataType();
218
219 float *fp = (float*)(tBuf->getBuffer());
220 #ifdef _OPENMP
221 #pragma omp parallel for
222 #endif
223 for(int i=0; i<bytes/sizeof(float); i++)
224 fp[i] = 0.f;
225 }
226 }
227
/* Write checkpoint files for every tensor in L that has a buffer of type
 * buftype (DATA, HISTORY or DIFF).  The file name is derived from the tensor
 * name (suffixed "_history" for HISTORY, "_grad" for DIFF), and the actual
 * write is delegated to the owning node's Checkpoint() method, dispatched on
 * the tensor category and owning node type.  At milestone epochs (30/60/80)
 * an additional copy is written into a per-epoch directory. */
void MLEngine::checkpoint(TensorList L, int buftype)
{
  for(Iter it=L.begin(); it != L.end(); it++)
  {
    Tensor* t = it->t;
    TensorBuf *tBuf;
    bool found=false;

    // Locate the buffer of the requested type for this tensor.
    for(int index=0; index<t->getNumDataBuffers(); index++)
    {
      tBuf = t->getBuf(index);
      if(tBuf->getBufferType() == buftype)
      {
        found = true;
        break;
      }
    }
    if(!found) continue;

    int tenType = t->getType();
    string tn = t->getTensorName();
    string n = checkpoint_dir_ + "/" + tn;
    if(buftype == HISTORY)
      n = n + "_history";
    else if(buftype == DIFF)
      n = n + "_grad";

    string nntype = dynamic_cast<NNNode*>(t->getOwner())->getNodeType();

    // Milestone epochs: dump batch-norm activations ("bn" tensors) into the
    // per-epoch directory as well.
    if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
    {
      if(tenType == ACT)
      {
        n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
        if(tn.find("bn") != tn.npos)
        {
          if(nntype == "FusedBatchNorm")
          {
            FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
            bn->Checkpoint(tBuf, n, checkpoint_format_);
          }
          else if(nntype == "FusedConvBN")
          {
            FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
            fcbn->Checkpoint(tBuf, n, checkpoint_format_);
          }
        }
      }
    }

    // Convolution weights/biases.
    if((tenType == CONVWEIGHT) || (tenType == CONVBIAS))
    {
      if(nntype == "Convolution")
      {
        ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
        cn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            // NOTE(review): every other branch suffixes "_grad" for DIFF;
            // confirm "_diff" is intentional here and not a typo.
            n = n + "_diff";
          cn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
      else if(nntype == "FusedConvBN")
      {
        FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
        fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
    }
    // Fully-connected weights/biases.
    else if((tenType == FCWEIGHT) || (tenType == FCBIAS))
    {
      FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
      fn->Checkpoint(tBuf, n, checkpoint_format_);
      if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
      {
        n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
        if(buftype == HISTORY)
          n = n + "_history";
        else if(buftype == DIFF)
          n = n + "_grad";
        fn->Checkpoint(tBuf, n, checkpoint_format_);
      }
    }
    // Batch-norm scale/shift and running mean/variance.
    else if((tenType == BNORMSCALE) || (tenType == BNORMSHIFT) || (tenType == BNORMMEAN) || (tenType == BNORMVAR))
    {
      if(nntype == "FusedBatchNorm")
      {
        FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
        bn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          bn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
      else if(nntype == "FusedConvBN")
      {
        FusedConvBNNode* fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
        fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
        {
          n = checkpoint_dir_ + to_string(current_epoch_) + "/" + tn;
          if(buftype == HISTORY)
            n = n + "_history";
          else if(buftype == DIFF)
            n = n + "_grad";
          fcbn->Checkpoint(tBuf, n, checkpoint_format_);
        }
      }
    }
  }
}
356
read_checkpoint_file(TensorBuf * tBuf,string filename,string format)357 void MLEngine::read_checkpoint_file(TensorBuf* tBuf, string filename, string format)
358 {
359 long long int bytes = tBuf->getBufferSize();
360 int dtype = tBuf->getDataType();
361
362 void* ptr;
363 ptr = tBuf->getBuffer();
364
365 FILE* f;
366 if(format == "binary")
367 {
368 f = fopen(filename.c_str(), "rb");
369 assert(f != NULL);
370 size_t b = fread(ptr, 1, bytes, f);
371 assert((long long int)b == bytes);
372 }
373 else
374 {
375 printf("Reading from %s\n",filename.c_str());
376 f = fopen(filename.c_str(), "r");
377 assert(f != NULL);
378 if(dtype == DT_FLOAT)
379 {
380 float* p = (float*)ptr;
381 for(int i=0; i < bytes/sizeof(float); i++)
382 fscanf(f, "%f", &p[i]);
383 }
384 }
385 fclose(f);
386
387 if(data_type_ == BF16 && (filename.find("wt") != filename.npos))
388 if(filename.find("history") == filename.npos)
389 convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)tBuf->getLPBuffer(), bytes/sizeof(float), 0);
390
391 }
392
load_checkpoint(TensorList L,int buftype,string format)393 void MLEngine::load_checkpoint(TensorList L, int buftype, string format)
394 {
395 TensorBuf* tBuf;
396
397 for(Iter it=L.begin(); it != L.end(); it++)
398 {
399 Tensor* t = it->t;
400 int tenType = t->getType();
401 if((tenType != CONVWEIGHT) && (tenType != CONVBIAS) && (tenType != FCWEIGHT) && (tenType != FCBIAS))
402 if((tenType != BNORMSCALE) && (tenType != BNORMSHIFT) && (tenType != BNORMMEAN) && (tenType != BNORMVAR))
403 continue;
404
405 bool found = false;
406 for(int index=0; index<t->getNumDataBuffers(); index++)
407 {
408 tBuf = t->getBuf(index);
409 if(tBuf->getBufferType() == buftype)
410 {
411 found = true;
412 break;
413 }
414 }
415 if(!found) continue;
416
417 string n = checkpoint_dir_ + "/" + t->getTensorName();
418
419 if(buftype == HISTORY)
420 n = n + "_history";
421
422 size_t pos;
423 while((pos = n.find("/", 10)) != n.npos)
424 n.replace(pos, 1, 1, '_');
425 read_checkpoint_file(tBuf, n, format);
426 }
427 }
428
canary_check(void * ptr,vector<int> & cp,int nc)429 void MLEngine::canary_check(void* ptr, vector<int>& cp, int nc)
430 {
431 if(ptr == NULL)
432 {
433 printf("FATAL: NULL pointer to buffer\n");
434 //exit(1);
435 }
436
437 int *p = (int*)ptr;
438 for(int i=0; i<START_GUARD_BAND/sizeof(int); i++)
439 {
440 // printf("p[%d] = %x\n",i, p[i]);
441 if(p[i] != 0x7f7f7f7f)
442 {
443 printf("Fatal: canary value overwritten at %d in buffer at %p\n",i, ptr);
444 //exit(1);
445 }
446 }
447
448 void *vp = (void*)(ptr + START_GUARD_BAND);
449
450 for(int i=0; i<nc; i++)
451 {
452 int next = cp[i];
453 vp = (void*)(vp + next);
454 int *pp = (int*)vp;
455 for(int j=0; j<END_GUARD_BAND/sizeof(int); j++)
456 {
457 // printf("pp[%d] = %x\n",j, pp[j]);
458 if(pp[j] != 0x7f7f7f7f)
459 {
460 printf("Fatal: canary value overwritten at %d in buffer at %p\n",j,pp);
461 //exit(1);
462 }
463 }
464 vp += END_GUARD_BAND;
465 }
466 }
467
waitForComms(string tenType)468 void MLEngine:: waitForComms(string tenType)
469 {
470 #ifdef USE_MLSL
471 if(tenType=="WEIGHT")
472 {
473 if(!wtgrad_comms_vec.empty())
474 {
475 for(int i=0; i<wtgrad_comms_vec.size(); i++)
476 wtgrad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
477 }
478 }
479 else if(tenType=="BIAS")
480 {
481 if(!bias_grad_comms_vec.empty())
482 {
483 for(int i=0; i<bias_grad_comms_vec.size(); i++)
484 {
485 bias_grad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
486 bias_grad_comms_vec[i]->GetParameterSet(1)->WaitGradientComm();
487 bias_grad_comms_vec[i]->GetParameterSet(2)->WaitGradientComm();
488 bias_grad_comms_vec[i]->GetParameterSet(3)->WaitGradientComm();
489 }
490 }
491 }
492 else if(tenType=="COMBO")
493 {
494 if(!combo_grad_comms_vec.empty())
495 {
496 for(int i=0; i<combo_grad_comms_vec.size(); i++)
497 {
498 combo_grad_comms_vec[i]->GetParameterSet(0)->WaitGradientComm();
499 combo_grad_comms_vec[i]->GetParameterSet(1)->WaitGradientComm();
500 combo_grad_comms_vec[i]->GetParameterSet(2)->WaitGradientComm();
501 combo_grad_comms_vec[i]->GetParameterSet(3)->WaitGradientComm();
502 combo_grad_comms_vec[i]->GetParameterSet(4)->WaitGradientComm();
503 }
504 }
505 }
506 #endif
507 }
508
run(int mode)509 void MLEngine::run(int mode)
510 {
511 if(mode == TRAIN)
512 {
513 if(load_from_checkpoint_)
514 {
515 FILE *f = fopen("checkpoint", "r");
516 if(f != NULL)
517 {
518 fscanf(f, "%d %f %f\n",¤t_epoch_, &lr_, &scf_);
519 fclose(f);
520 }
521 else
522 printf("No checkpoint state file to read\n");
523
524 if(current_epoch_ != num_epochs_ - 1)
525 current_epoch_++;
526 load_checkpoint(wTList_, DATA, checkpoint_format_);
527 load_checkpoint(wTList_, HISTORY, checkpoint_format_);
528 load_checkpoint(biasTList_, DATA, checkpoint_format_);
529 load_checkpoint(biasTList_, HISTORY, checkpoint_format_);
530 load_checkpoint(statsTList_, DATA, checkpoint_format_);
531
532 #ifdef _OPENMP
533 #pragma omp parallel
534 #endif
535 {
536 int tid = omp_get_thread_num();
537 int ntps = num_threads_/NUM_NUMA_NODES;
538 int n = tid/ntps;
539 int w = total_weights_;
540 int b = total_biases_;
541
542 if(n != 0 && tid % ntps == 0)
543 {
544 float *wptr = (float*)weight_buf_[n];
545 #if 1
546 float *bptr = (float*)bias_buf_[n];
547 float *sptr = (float*)stats_buf_[n];
548 #endif
549
550 #pragma omp simd
551 for(int i=0; i<w; i++)
552 wptr[i] = ((float*)weight_buf_[0])[i];
553
554 #if 1
555 #pragma omp simd
556 for(int i=0; i<b; i++)
557 {
558 bptr[i] = ((float*)bias_buf_[0])[i];
559 sptr[i] = ((float*)stats_buf_[0])[i];
560 }
561 #endif
562 if(lpweight_buf_[0] != NULL)
563 {
564 libxsmm_bfloat16 *lwptr = (libxsmm_bfloat16*)lpweight_buf_[n];
565 #pragma omp simd
566 for(int i=0; i<w; i++)
567 lwptr[i] = ((libxsmm_bfloat16*)lpweight_buf_[0])[i];
568 }
569 }
570 }
571 load_from_checkpoint_ = false;
572 }
573
574 fflush(stdout);
575
576 #ifdef USE_MLSL
577 data_parallelism->Barrier(MLSL::GT_DATA);
578 #endif
579
580 // current_epoch_ is set in create() function or by checkpoint code above
581 for(; current_epoch_ < num_epochs_; current_epoch_++)
582 {
583 // Tell data node that it should use training data
584 exec_mode_ = TRAIN;
585 if(global_node_id_ == 0)
586 {
587 printf("===========================================\n");
588 printf("TRAIN mode, epoch %d, training batches %d\n", current_epoch_, num_train_batches_);
589 printf("===========================================\n");
590 }
591
592 // Run training network for an epoch
593 struct timeval tvs, tve, tvts, tvte, tvis, tvie;
594 double fbtime, runtime = 0;
595
596 for(; current_batch_<num_train_batches_; current_batch_++)
597 {
598 if(global_node_id_ == 0 && current_batch_ % 100 == 0)
599 printf("Executing batch number %d\n",current_batch_);
600
601 gettimeofday(&tvs, NULL);
602
603 for(auto it = etg_[TRAIN].begin(); it != etg_[TRAIN].end(); it++)
604 {
605 #ifdef TIMING
606 gettimeofday(&tvts, NULL);
607 #endif
608 (*it)->invoke();
609
610 #ifdef TIMING
611 gettimeofday(&tvte, NULL);
612 double tasktime = (tvte.tv_sec*1e6 + tvte.tv_usec) - (tvts.tv_sec*1e6 + tvts.tv_usec);
613 NNNode *nn = dynamic_cast<NNNode*>((*it)->getNode());
614 if(global_node_id_ == 0)
615 printf("Node %s (task %d) time = %f ms\n",nn->getNodeName().c_str(), (*it)->getBasicTaskId(), tasktime/1000);
616 #endif
617 }
618
619 if(solver_->getGlobalFlag())
620 {
621 #ifdef TIMING
622 gettimeofday(&tvis, NULL);
623 #endif
624
625 #ifdef DUMP_WT
626 if(global_node_id_ == 0)
627 if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
628 if(current_batch_ == num_train_batches_-1)
629 checkpoint(wTList_, DIFF);
630 #endif
631
632 #ifdef USE_MLSL
633 waitForComms("WEIGHT");
634 waitForComms("BIAS");
635 waitForComms("COMBO");
636 #endif
637
638 #ifdef MLSL
639 data_parallelism->Barrier(MLSL::GT_DATA);
640 #endif
641
642 #if 0
643 solver_->applyUpdate((float**)weight_buf_, (float**)winc_buf_, wdiff_buf_, total_weights_, (float**)wt_lr_mult_, (float**)wt_decay_mult_, "WEIGHT");
644 #else
645 solver_->applyUpdate((float**)weight_buf_, (float**)winc_buf_, wdiff_buf_, total_weights_, 1.0, 1.0, "WEIGHT");
646 #endif
647 if(data_type_ == BF16)
648 convert_f32_bf16((float**)weight_buf_, (libxsmm_bfloat16**)lpweight_buf_, total_weights_);
649
650 #if 0
651 solver_->applyUpdate((float**)bias_buf_, (float**)biinc_buf_, bidiff_buf_, total_biases_, (float**)bias_lr_mult_, (float**)bias_decay_mult_, "BIAS");
652 #else
653 #if 1
654 solver_->applyUpdate((float**)bias_buf_, (float**)biinc_buf_, bidiff_buf_, total_biases_, 1.0, 0.0, "BIAS");
655 #else
656 solver_->applyUpdate((float*)bias_buf_, (float*)biinc_buf_, bidiff_buf_, total_biases_, 1.0, 0.0, "BIAS");
657 #endif
658 #endif
659
660 #ifdef TIMING
661 gettimeofday(&tvie, NULL);
662 double sgdtime = (tvie.tv_sec + tvie.tv_usec*1e-6) - (tvis.tv_sec + tvis.tv_usec*1e-6);
663 printf("global sgd time: %f ms\n",sgdtime*1000);
664 #endif
665 }
666
667 gettimeofday(&tve, NULL);
668 fbtime = (tve.tv_sec + tve.tv_usec*1e-6) - (tvs.tv_sec + tvs.tv_usec*1e-6);
669 if(global_node_id_ == 0 && current_batch_ % 100 == 0)
670 printf("Fwd-Bwd time: %f ms\n",fbtime*1000);
671
672 if ( current_batch_ > 1 )
673 runtime += fbtime;
674
675 #ifdef CANARY_CHECK
676 canary_check(input_buf_, input_can_ptr, ic);
677 canary_check(fact_buf_, fact_can_ptr, fac);
678 canary_check(bact_buf_, bact_can_ptr, bac);
679 #endif
680 }
681
682 current_batch_ = 0;
683
684 if ( num_train_batches_ > 1 ) {
685 char hostname[HOST_NAME_MAX + 1];
686 gethostname(hostname, HOST_NAME_MAX + 1);
687 printf("%s; Average Training time = %f seconds", hostname, runtime/((double)(num_train_batches_-2)));
688 if(runtime > 0) {
689 printf("; Average Training throughput = %f images/s\n", ((double)(batch_size_*(num_train_batches_-2)))/runtime);
690 } else {
691 printf("\n");
692 }
693 }
694
695 // Checkpoint weights and biases
696 if(global_node_id_ == 0)
697 {
698 checkpoint(wTList_, DATA);
699 checkpoint(wTList_, HISTORY);
700 checkpoint(biasTList_, DATA);
701 checkpoint(biasTList_, HISTORY);
702 checkpoint(statsTList_, DATA);
703
704 #ifdef DUMP_ACT_DATA
705 if(current_epoch_ == 30 || current_epoch_ == 60 || current_epoch_ == 80)
706 {
707 checkpoint(outTList_, DATA);
708 checkpoint(outTList_, DIFF);
709 }
710 #endif
711
712 FILE* f = fopen("checkpoint", "w");
713 if(f != NULL)
714 {
715 fprintf(f, "%d %10g %10g\n",current_epoch_, lr_, scf_);
716 fclose(f);
717 }
718 }
719 #ifdef USE_MLSL
720 data_parallelism->Barrier(MLSL::GT_DATA);
721 #endif
722
723 // Tell data node that it should use test data
724 exec_mode_ = VAL;
725
726 if(global_node_id_ == 0)
727 {
728 printf("===========================================\n");
729 printf("VAL mode, testing batches %d\n", num_test_batches_);
730 printf("===========================================\n");
731 }
732
733 // Run validation network at end of each epoch
734 for(; current_batch_<num_test_batches_; current_batch_++)
735 {
736 for(int v=0; v<num_test_views_; v++)
737 for(auto it = etg_[VAL].begin(); it != etg_[VAL].end(); it++)
738 (*it)->invoke();
739 }
740
741 current_batch_ = 0;
742
743 #ifdef CANARY_CHECK
744 canary_check(input_buf_, input_can_ptr, ic);
745 canary_check(fact_buf_, fact_can_ptr, fac);
746 canary_check(weight_buf_, wt_can_ptr, wtc);
747 canary_check(bias_buf_, bias_can_ptr, bic);
748 #endif
749 }
750
751 #ifdef USE_MLSL
752 MLSL::Environment::GetEnv().Free(input_buf_);
753 MLSL::Environment::GetEnv().Free(fact_buf_);
754 MLSL::Environment::GetEnv().Free(bact_buf_);
755 #else
756 libxsmm_free(input_buf_);
757 libxsmm_free(fact_buf_);
758 libxsmm_free(bact_buf_);
759 #endif
760
761 for(int n=0; n<NUM_NUMA_NODES; n++)
762 {
763 #ifdef USE_MLSL
764 MLSL::Environment::GetEnv().Free(weight_buf_[n]);
765 if(lpweight_buf_[n] != NULL)
766 MLSL::Environment::GetEnv().Free(lpweight_buf_[n]);
767 MLSL::Environment::GetEnv().Free(wdiff_buf_[n]);
768 if(lpwdiff_buf_[n] != NULL)
769 MLSL::Environment::GetEnv().Free(lpwdiff_buf_[n]);
770 MLSL::Environment::GetEnv().Free(winc_buf_[n]);
771 #if 1
772 MLSL::Environment::GetEnv().Free(bias_buf_[n]);
773 MLSL::Environment::GetEnv().Free(bidiff_buf_[n]);
774 MLSL::Environment::GetEnv().Free(biinc_buf_[n]);
775 MLSL::Environment::GetEnv().Free(stats_buf_[n]);
776 #else
777 MLSL::Environment::GetEnv().Free(bias_buf_);
778 MLSL::Environment::GetEnv().Free(bidiff_buf_);
779 MLSL::Environment::GetEnv().Free(biinc_buf_);
780 MLSL::Environment::GetEnv().Free(stats_buf_);
781 #endif
782 #else
783 libxsmm_free(weight_buf_[n]);
784 libxsmm_free(wdiff_buf_[n]);
785 if(lpweight_buf_[n] != NULL)
786 libxsmm_free(lpweight_buf_[n]);
787 if(lpwdiff_buf_[n] != NULL)
788 libxsmm_free(lpwdiff_buf_[n]);
789 libxsmm_free(winc_buf_[n]);
790 #if 1
791 libxsmm_free(bias_buf_[n]);
792 libxsmm_free(bidiff_buf_[n]);
793 libxsmm_free(biinc_buf_[n]);
794 libxsmm_free(stats_buf_[n]);
795 #else
796 libxsmm_free(bias_buf_);
797 libxsmm_free(bidiff_buf_);
798 libxsmm_free(biinc_buf_);
799 libxsmm_free(stats_buf_);
800 #endif
801 #endif
802 }
803 }
804 else if(mode == TEST)
805 {
806 exec_mode_ = TEST;
807
808 FILE *f = fopen("checkpoint", "r");
809 fscanf(f, "%d %f %f\n",¤t_epoch_, &lr_, &scf_);
810 fclose(f);
811
812 printf("====================================================================\n");
813 printf("TEST mode, testing batches %d, scaling factor %.10f\n", num_test_batches_, scf_);
814 printf("====================================================================\n");
815
816 load_checkpoint(wTList_, DATA, checkpoint_format_);
817 load_checkpoint(biasTList_, DATA, checkpoint_format_);
818 load_checkpoint(statsTList_, DATA, checkpoint_format_);
819
820 #ifdef _OPENMP
821 #pragma omp parallel
822 #endif
823 {
824 int tid = omp_get_thread_num();
825 int ntps = num_threads_/NUM_NUMA_NODES;
826 int n = tid/ntps;
827 int w = total_weights_;
828 int b = total_biases_;
829
830 if(n != 0 && tid % ntps == 0)
831 {
832 float *wptr = (float*)weight_buf_[n];
833 #if 1
834 float *bptr = (float*)bias_buf_[n];
835 float *sptr = (float*)stats_buf_[n];
836 #endif
837
838 #pragma omp simd
839 for(int i=0; i<w; i++)
840 wptr[i] = ((float*)weight_buf_[0])[i];
841
842 #if 1
843 #pragma omp simd
844 for(int i=0; i<b; i++)
845 {
846 bptr[i] = ((float*)bias_buf_[0])[i];
847 sptr[i] = ((float*)stats_buf_[0])[i];
848 }
849 #endif
850 if(lpweight_buf_[0] != NULL)
851 {
852 libxsmm_bfloat16 *lwptr = (libxsmm_bfloat16*)lpweight_buf_[n];
853 #pragma omp simd
854 for(int i=0; i<w; i++)
855 lwptr[i] = ((libxsmm_bfloat16*)lpweight_buf_[0])[i];
856 }
857 }
858 }
859
860 // Run test network when command-line mode is set to "test"
861 for(int b=0; b<num_test_batches_; b++)
862 {
863 for(auto it = etg_[TEST].begin(); it != etg_[TEST].end(); it++)
864 (*it)->invoke();
865 }
866 }
867 }
868
/* Convert len floats from 'in' to bfloat16 in 'out', using only the OpenMP
 * threads that belong to numa_node.  Each value receives a round-to-nearest-
 * even adjustment before truncation to bfloat16.
 * NOTE(review): strides 16 elements per iteration; assumes len and both
 * buffers are padded to a multiple of 16 (weight buffers are rounded up to
 * num_threads_*VLEN elsewhere) -- confirm for all call sites. */
void MLEngine::convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len, int numa_node)
{

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = num_threads_/NUM_NUMA_NODES;  // threads per NUMA node
    int n = tid/ntps;                        // NUMA node this thread belongs to
    int ltid = tid - numa_node*ntps;         // thread rank within numa_node

    // Only the target node's threads participate in the conversion.
    if(n == numa_node)
    {
      // Split [0, len) into contiguous per-thread chunks, clamped to len.
      int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
      int tb = (ltid*jobs < len) ? ltid*jobs : len;
      int te = ((ltid+1)*jobs < len) ? (ltid+1)*jobs : len;

      for (int i = tb; i < te; i+=16 ) {
        __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f( _mm512_loadu_ps( in+i ) );
        __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f( vfp32 );
        _mm256_storeu_si256( (__m256i*)(out+i), vbfp16 );
      }
    }
  }
}
895
/* Convert len floats to bfloat16 on every NUMA node at once: each thread
 * converts its chunk of the buffer pair (in[n], out[n]) that belongs to its
 * own node n, so all replicas are converted in a single parallel region.
 * NOTE(review): strides 16 per iteration; assumes len is padded to a
 * multiple of 16 -- confirm at call sites. */
void MLEngine::convert_f32_bf16(float** in, libxsmm_bfloat16** out, int len)
{
#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = omp_get_thread_num();
    int ntps = num_threads_/NUM_NUMA_NODES;  // threads per NUMA node
    int n = tid/ntps;                        // this thread's NUMA node
    int ltid = tid - n*ntps;                 // thread rank within its node

    // Buffers replicated on this thread's node.
    float *inp = in[n];
    libxsmm_bfloat16 *outp = out[n];

    // Split [0, len) into contiguous per-thread chunks, clamped to len.
    int jobs = (len % ntps == 0) ? len/ntps : len/ntps + 1;
    int tb = (ltid*jobs < len) ? ltid*jobs : len;
    int te = ((ltid+1)*jobs < len) ? (ltid+1)*jobs : len;

    for (int i = tb; i < te; i+=16 ) {
      __m512 vfp32 = gxm_fp32_to_bfp16_rne_adjustment_avx512f(_mm512_loadu_ps(inp + i));
      __m256i vbfp16 = gxm_fp32_to_bfp16_truncate_avx512f(vfp32);
      _mm256_storeu_si256( (__m256i*)(outp+i), vbfp16 );
    }
  }
}
/* Convert len bfloat16 values in 'in' back to float in 'out', in parallel.
 * NOTE(review): strides 16 per iteration; assumes len is padded to a
 * multiple of 16 -- confirm at call sites. */
void MLEngine::convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len)
{
  int i;

#ifdef _OPENMP
#pragma omp parallel for private(i)
#endif
  for ( i = 0; i < len; i+=16 ) {
    __m256i vbfp16 = _mm256_loadu_si256( (const __m256i*)(in+i) );
    __m512 vfp32 = gxm_bfp16_to_fp32_avx512f( vbfp16 );
    _mm512_storeu_ps( out+i, vfp32 );
  }
}
934
allocate_memory(string tenType,TensorList L,int buftype,vector<int> & can_ptr,int * nc,long long int * bufsize)935 void MLEngine::allocate_memory(string tenType, TensorList L, int buftype, vector<int>& can_ptr, int* nc, long long int* bufsize)
936 {
937 bool ttp = false; //(tenType != "WEIGHT") & (tenType != "BIAS");
938
939 long long int s = ttp ? START_GUARD_BAND : 0;
940 TensorBuf* tBuf;
941 int num_canaries = 0;
942
943 float* lrptr, *decptr;
944
945 // Get total buffer size required for tensors of type buftype
946 for(Iter it=L.begin(); it != L.end(); it++)
947 {
948 Tensor* t = it->t;
949
950 bool found = false;
951 for(int i=0; i<t->getNumDataBuffers(); i++)
952 {
953 tBuf = t->getBuf(i);
954 if(tBuf->getBufferType() == buftype)
955 {
956 found = true;
957 break;
958 }
959 }
960 if(!found) continue;
961
962 long long int size = tBuf->getBufferSize();
963 if(size > 0)
964 {
965 if(global_node_id_ == 0)
966 {
967 printf("Tensor %s needs %lld bytes for buffer %d\n", t->getTensorName().c_str(), size, buftype);
968 fflush(stdout);
969 }
970 s += size;
971 if(ttp)
972 s += END_GUARD_BAND;
973
974 if(ttp)
975 num_canaries++;
976 }
977 }
978
979 if(tenType == "WEIGHT")
980 total_weights_ = s/sizeof(float);
981 else if(tenType == "BIAS" || tenType == "STATS")
982 total_biases_ = s/sizeof(float);
983
984 if(solver_->getGlobalFlag())
985 {
986 if(tenType == "WEIGHT")
987 {
988 #ifdef BF16_MLSL
989 if(buftype == DIFF)
990 {
991 if(data_type_ == FLOAT)
992 total_weights_ = s/sizeof(float);
993 else if(data_type_ == BF16)
994 total_weights_ = s/sizeof(libxsmm_bfloat16);
995 }
996 else
997 #endif
998 total_weights_ = s/sizeof(float);
999
1000 int factor = num_threads_ * VLEN;
1001 int nwt = (total_weights_ + factor - 1)/factor;
1002 total_weights_ = nwt * factor;
1003
1004 #ifdef BF16_MLSL
1005 if(buftype == DIFF)
1006 {
1007 if(data_type_ == FLOAT)
1008 s = total_weights_ * sizeof(float);
1009 else if(data_type_ == BF16)
1010 s = total_weights_ * sizeof(libxsmm_bfloat16);
1011 }
1012 else
1013 #endif
1014 s = total_weights_ * sizeof(float);
1015 }
1016 else if(tenType == "BIAS" || tenType == "STATS")
1017 {
1018 total_biases_ = s / sizeof(float);
1019 int factor = num_threads_ * VLEN;
1020 int nwt = (total_biases_ + factor - 1)/factor;
1021 total_biases_ = nwt * factor;
1022
1023 s = total_biases_ * sizeof(float);
1024 }
1025 }
1026
1027 // Number of guard bands in tensor; used for canary checking
1028 *nc = num_canaries;
1029
1030 // Allocate memory
1031 #ifdef BF16_MLSL
1032 bool lp = (data_type_ == BF16) && (tenType=="WEIGHT") && (buftype == DATA);
1033 #else
1034 bool lp = (data_type_ == BF16) && (tenType=="WEIGHT");
1035 #endif
1036
1037 void *buf_;
1038 void **ptrptr, **lptrptr=NULL;
1039
1040 #if 0 //def USE_MLSL
1041 s = ALIGN_SIZE(s, 2097152);
1042 #endif
1043
1044 if(tenType=="INPUT")
1045 {
1046 #ifdef USE_MLSL
1047 buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1048 #else
1049 buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1050 #endif
1051 input_buf_ = buf_;
1052 }
1053 else if(tenType == "FACT")
1054 {
1055 #ifdef USE_MLSL
1056 buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1057 #else
1058 buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1059 #endif
1060 fact_buf_ = buf_;
1061 }
1062 else if(tenType == "WEIGHT")
1063 {
1064 if(buftype == DATA)
1065 {
1066 for(int n=0; n<NUM_NUMA_NODES; n++)
1067 {
1068 #ifdef USE_MLSL
1069 weight_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1070 #else
1071 weight_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1072 #endif
1073 if(lp)
1074 #ifdef USE_MLSL
1075 lpweight_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s/sizeof(libxsmm_bfloat16), 2097152);
1076 #else
1077 lpweight_buf_[n] = (void*)libxsmm_aligned_malloc(s/sizeof(libxsmm_bfloat16), 2097152);
1078 #endif
1079 }
1080 buf_ = weight_buf_[0];
1081 ptrptr = weight_buf_;
1082 if(lp)
1083 lptrptr = lpweight_buf_;
1084 }
1085 else if(buftype == DIFF)
1086 {
1087 for(int n=0; n<NUM_NUMA_NODES; n++)
1088 {
1089 #ifdef USE_MLSL
1090 wdiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1091 #else
1092 wdiff_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1093 #endif
1094 if(lp)
1095 #ifdef USE_MLSL
1096 lpwdiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s/sizeof(libxsmm_bfloat16), 2097152);
1097 #else
1098 lpwdiff_buf_[n] = (void*)libxsmm_aligned_malloc(s/sizeof(libxsmm_bfloat16), 2097152);
1099 #endif
1100 }
1101 buf_ = wdiff_buf_[0];
1102 ptrptr = wdiff_buf_;
1103 if(lp)
1104 lptrptr = lpwdiff_buf_;
1105 }
1106 else if(buftype == HISTORY)
1107 {
1108 for(int n=0; n<NUM_NUMA_NODES; n++)
1109 #ifdef USE_MLSL
1110 winc_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1111 #else
1112 winc_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1113 #endif
1114 buf_ = winc_buf_[0];
1115 ptrptr = winc_buf_;
1116 }
1117 }
1118 else if(tenType == "BACT")
1119 {
1120 #ifdef USE_MLSL
1121 buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1122 #else
1123 buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1124 #endif
1125 bact_buf_ = buf_;
1126 }
1127 else if(tenType == "STATS")
1128 {
1129 #if 1
1130 for(int n=0; n<NUM_NUMA_NODES; n++)
1131 #ifdef USE_MLSL
1132 stats_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1133 #else
1134 stats_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1135 #endif
1136 buf_ = stats_buf_[0];
1137 ptrptr = stats_buf_;
1138 #else
1139 #ifdef USE_MLSL
1140 stats_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1141 #else
1142 stats_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1143 #endif
1144 buf_ = stats_buf_;
1145 #endif
1146 }
1147 else if(tenType == "BIAS")
1148 {
1149 if(buftype == DATA)
1150 {
1151 #if 1
1152 for(int n=0; n<NUM_NUMA_NODES; n++)
1153 #ifdef USE_MLSL
1154 bias_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1155 #else
1156 bias_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1157 #endif
1158 buf_ = bias_buf_[0];
1159 ptrptr = bias_buf_;
1160 #else
1161 #ifdef USE_MLSL
1162 bias_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1163 #else
1164 bias_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1165 #endif
1166 buf_ = bias_buf_;
1167 #endif
1168 }
1169 else if(buftype == DIFF)
1170 {
1171 #if 1
1172 for(int n=0; n<NUM_NUMA_NODES; n++)
1173 #ifdef USE_MLSL
1174 bidiff_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1175 #else
1176 bidiff_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1177 #endif
1178 buf_ = bidiff_buf_[0];
1179 ptrptr = bidiff_buf_;
1180 #else
1181 #ifdef USE_MLSL
1182 bidiff_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1183 #else
1184 bidiff_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1185 #endif
1186 buf_ = bidiff_buf_;
1187 #endif
1188 }
1189 else if(buftype == HISTORY)
1190 {
1191 #if 1
1192 for(int n=0; n<NUM_NUMA_NODES; n++)
1193 #ifdef USE_MLSL
1194 biinc_buf_[n] = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1195 #else
1196 biinc_buf_[n] = (void*)libxsmm_aligned_malloc(s, 2097152);
1197 #endif
1198 buf_ = biinc_buf_[0];
1199 ptrptr = biinc_buf_;
1200 #else
1201 #ifdef USE_MLSL
1202 biinc_buf_ = (void*)MLSL::Environment::GetEnv().Alloc(s, 2097152);
1203 #else
1204 biinc_buf_ = (void*)libxsmm_aligned_malloc(s, 2097152);
1205 #endif
1206 buf_ = biinc_buf_;
1207 #endif
1208 }
1209 }
1210
// Total buffer size, including guard bands before and after each buffer (currently 64 bytes long)
1212 *bufsize = s + (lp ? s/sizeof(libxsmm_bfloat16) : 0);
1213
1214 #if 0
1215 printf("Tensor with buffers %d @ %p with total size %lld\n",buftype, buf_, s);
1216 fflush(stdout);
1217
1218 if(buf_ != NULL)
1219 {
1220 #ifndef USE_NUMA
1221 memset(buf_, 0, s);
1222 #endif
1223 }
1224 else {
1225 printf("could not allocate tensor memory.. exiting\n");
1226 exit(-1);
1227 }
1228
1229 if(lp && lpweight_buf_==NULL)
1230 {
1231 printf("could not allocate low precision weights memory.. exiting\n");
1232 exit(-1);
1233 }
1234
1235 if(solver_->getGlobalFlag())
1236 {
1237 if(tenType == "WEIGHT" && buftype == DIFF)
1238 {
1239 for(int n=0; n<NUM_NUMA_NODES; n++)
1240 {
1241 wt_lr_mult_[n] = (float*)libxsmm_aligned_malloc(total_weights_*sizeof(float), 2097152);
1242 if(wt_lr_mult_[n] != NULL)
1243 {
1244 float *ptr = wt_lr_mult_[n];
1245
1246 #ifdef _OPENMP
1247 #pragma omp parallel
1248 #endif
1249 {
1250 int tid = omp_get_thread_num();
1251 int ntps = num_threads_/NUM_NUMA_NODES;
1252 int s = tid/ntps;
1253 if(s == n && tid % ntps == 0)
1254 for(int i=0; i<total_weights_; i++)
1255 ptr[i] = 0.0;
1256 }
1257 }
1258
1259 wt_decay_mult_[n] = (float*)libxsmm_aligned_malloc(total_weights_*sizeof(float), 2097152);
1260 if(wt_decay_mult_[n] != NULL)
1261 {
1262 float *ptr = wt_decay_mult_[n];
1263
1264 #ifdef _OPENMP
1265 #pragma omp parallel
1266 #endif
1267 {
1268 int tid = omp_get_thread_num();
1269 int ntps = num_threads_/NUM_NUMA_NODES;
1270 int s = tid/ntps;
1271 if(s == n && tid % ntps == 0)
1272 for(int i=0; i<total_weights_; i++)
1273 ptr[i] = 0.0;
1274 }
1275 }
1276 }
1277 lrptr = wt_lr_mult_[0];
1278 decptr = wt_decay_mult_[0];
1279 }
1280 else if(tenType == "BIAS" && buftype == DIFF)
1281 {
1282 for(int n=0; n<NUM_NUMA_NODES; n++)
1283 {
1284 bias_lr_mult_[n] = (float*)_mm_malloc(total_biases_*sizeof(float), 64);
1285 if(bias_lr_mult_[n] != NULL)
1286 {
1287 float *ptr = bias_lr_mult_[n];
1288
1289 #ifdef _OPENMP
1290 #pragma omp parallel
1291 #endif
1292 {
1293 int tid = omp_get_thread_num();
1294 int ntps = num_threads_/NUM_NUMA_NODES;
1295 int s = tid/ntps;
1296 if(s == n && tid % ntps == 0)
1297 for(int i=0; i<total_biases_; i++)
1298 ptr[i] = 0.0;
1299 }
1300 }
1301
1302 bias_decay_mult_[n] = (float*)_mm_malloc(total_biases_*sizeof(float), 64);
1303 if(bias_decay_mult_[n] != NULL)
1304 {
1305 float *ptr = bias_decay_mult_[n];
1306
1307 #ifdef _OPENMP
1308 #pragma omp parallel
1309 #endif
1310 {
1311 int tid = omp_get_thread_num();
1312 int ntps = num_threads_/NUM_NUMA_NODES;
1313 int s = tid/ntps;
1314 if(s == n && tid % ntps == 0)
1315 for(int i=0; i<total_biases_; i++)
1316 ptr[i] = 0.0;
1317 }
1318 }
1319 }
1320 lrptr = bias_lr_mult_[0];
1321 decptr = bias_decay_mult_[0];
1322 }
1323 }
1324 #endif
1325
1326 if(ttp)
1327 memset(buf_, CANARY, START_GUARD_BAND);
1328
1329 long long int bytes=0, lpbytes=0;
1330 int offset=0, bias_offset=0;
1331
1332 //Set up tensor buffer pointers
1333 void* ptr = ttp ? buf_ + START_GUARD_BAND : buf_;
1334 void* lptr = lp ? lpweight_buf_[0] : NULL;
1335 void* lgptr = lp ? lpwdiff_buf_[0] : NULL;
1336
1337 for(Iter it=L.begin(); it != L.end(); it++)
1338 {
1339 Tensor* t = it->t;
1340
1341 bool found = false;
1342 for(int i=0; i<t->getNumDataBuffers(); i++)
1343 {
1344 tBuf = t->getBuf(i);
1345 if(tBuf->getBufferType() == buftype)
1346 {
1347 found = true;
1348 break;
1349 }
1350 }
1351 if(!found) continue;
1352
1353 // Don't process Split nodes further for forward activations
1354 string nntype = dynamic_cast<NNNode*>(t->getOwner())->getNodeType();
1355 if(nntype.find("Split") != nntype.npos && buftype == DATA)
1356 continue;
1357
1358 // Scrub or initialize buffers appropriately
1359 bytes = tBuf->getBufferSize();
1360 assert(ptr+bytes <= buf_+s);
1361
1362 lpbytes = lp ? bytes/sizeof(libxsmm_bfloat16) : 0;
1363
1364 #ifndef USE_NUMA
1365 if(t->getType() == INPUT || t->getType() == ACT)
1366 {
1367 if(bytes > 0)
1368 memset(ptr, 0, bytes);
1369 }
1370 #endif
1371
1372 int dtype = tBuf->getDataType();
1373
// Set each node's tensor buffer pointers to the appropriate location in the global buffer
1375 if(tenType == "WEIGHT" || tenType == "BIAS" || tenType == "STATS")
1376 {
1377 if(buftype == DATA || buftype == DIFF)
1378 {
1379 tBuf->setBufferPtr(ptrptr);
1380 tBuf->setOffset(offset);
1381 }
1382 tBuf->setBuffer(ptr);
1383
1384 if(lp)
1385 {
1386 if(buftype == DATA)
1387 tBuf->setLPBuffer(lptr);
1388 else if(buftype == DIFF)
1389 tBuf->setLPBuffer(lgptr);
1390 tBuf->setLPBufferPtr(lptrptr);
1391 }
1392 }
1393 else
1394 tBuf->setBuffer(ptr);
1395
// If weight or bias tensor, call corresponding initialization function (for training only)
1397 if(!is_inference_only())
1398 {
1399 int tType = t->getType();
1400 if(tType == CONVWEIGHT)
1401 {
1402 if(nntype == "FusedConvBN")
1403 {
1404 FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
1405 assert(bytes > 0);
1406 if(!load_from_checkpoint_)
1407 {
1408 fcbn->fillWeightBuffers(tBuf, buftype, bytes);
1409 #if 0
1410 if(lp)
1411 convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
1412 #endif
1413 }
1414 #if 0
1415 if(solver_->getGlobalFlag())
1416 if(buftype == DIFF)
1417 if(data_type_ == FLOAT)
1418 fcbn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
1419 else if(data_type_ == BF16)
1420 fcbn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
1421 #endif
1422 }
1423 else if(nntype == "Convolution")
1424 {
1425 ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
1426 assert(bytes > 0);
1427 if(!load_from_checkpoint_)
1428 {
1429 cn->fillWeightBuffers(tBuf, buftype, bytes);
1430 #if 0
1431 if(lp)
1432 convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
1433 #endif
1434 }
1435
1436 #if 0
1437 if(solver_->getGlobalFlag())
1438 if(buftype == DIFF)
1439 if(data_type_ == FLOAT)
1440 cn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
1441 else if(data_type_ == BF16)
1442 cn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
1443 #endif
1444 }
1445 }
1446 else if(tType == CONVBIAS)
1447 {
1448 ConvNode* cn = dynamic_cast<ConvNode*>(t->getOwner());
1449 assert(bytes > 0);
1450 if(!load_from_checkpoint_)
1451 cn->fillBiasBuffers(tBuf, buftype, bytes);
1452 #if 0
1453 if(solver_->getGlobalFlag())
1454 if(buftype == DIFF)
1455 cn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
1456 #endif
1457 }
1458 else if(tType == FCWEIGHT)
1459 {
1460 FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
1461 assert(bytes > 0);
1462 if(!load_from_checkpoint_)
1463 {
1464 fn->fillWeightBuffers(tBuf, buftype, bytes);
1465 #if 0
1466 if(lp)
1467 convert_f32_bf16((float*)ptr, (libxsmm_bfloat16*)lptr, lpbytes/sizeof(libxsmm_bfloat16), 0);
1468 #endif
1469 }
1470
1471 #if 0
1472 if(solver_->getGlobalFlag())
1473 if(buftype == DIFF)
1474 if(data_type_ == FLOAT)
1475 fn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(float));
1476 else if(data_type_ == BF16)
1477 fn->fillWeightMultipliers(lrptr, decptr, bytes/sizeof(libxsmm_bfloat16));
1478 #endif
1479 }
1480 else if(tType == FCBIAS)
1481 {
1482 FCNode* fn = dynamic_cast<FCNode*>(t->getOwner());
1483 assert(bytes > 0);
1484 if(!load_from_checkpoint_)
1485 fn->fillBiasBuffers(tBuf, buftype, bytes);
1486 #if 0
1487 if(solver_->getGlobalFlag())
1488 if(buftype == DIFF)
1489 fn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
1490 #endif
1491 }
1492 else if((tType == BNORMSCALE) || (tType == BNORMSHIFT))
1493 {
1494 if(nntype == "FusedConvBN")
1495 {
1496 FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
1497 assert(bytes > 0);
1498 if(!load_from_checkpoint_)
1499 fcbn->fillBuffer(tBuf, buftype, bytes);
1500 #if 0
1501 if(solver_->getGlobalFlag())
1502 if(buftype == DIFF)
1503 fcbn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
1504 #endif
1505 }
1506 else if(nntype == "FusedBatchNorm")
1507 {
1508 FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
1509 assert(bytes > 0);
1510 if(!load_from_checkpoint_)
1511 bn->fillBuffer(tBuf, buftype, bytes);
1512 #if 0
1513 if(solver_->getGlobalFlag())
1514 if(buftype == DIFF)
1515 bn->fillBiasMultipliers(lrptr, decptr, bytes/sizeof(float));
1516 #endif
1517 }
1518 }
1519 else if((tType == BNORMMEAN) || (tType == BNORMVAR))
1520 {
1521 if(nntype == "FusedConvBN")
1522 {
1523 FusedConvBNNode *fcbn = dynamic_cast<FusedConvBNNode*>(t->getOwner());
1524 assert(bytes > 0);
1525 if(!load_from_checkpoint_)
1526 fcbn->fillBuffer(tBuf, buftype, bytes);
1527 }
1528 else if(nntype == "FusedBatchNorm")
1529 {
1530 FusedBNormNode* bn = dynamic_cast<FusedBNormNode*>(t->getOwner());
1531 assert(bytes > 0);
1532 if(!load_from_checkpoint_)
1533 bn->fillBuffer(tBuf, buftype, bytes);
1534 }
1535 }
1536 }
1537
1538 if(bytes > 0)
1539 {
1540 ptr += bytes;
1541 if(lp)
1542 if(buftype == DATA)
1543 lptr += lpbytes;
1544 #ifndef BF16_MLSL
1545 else if(buftype == DIFF)
1546 lgptr += lpbytes;
1547 #endif
1548
1549 #ifdef BF16_MLSL
1550 if(tenType == "WEIGHT" && buftype == DATA)
1551 offset += bytes/sizeof(float);
1552 else if(tenType == "WEIGHT" && buftype == DIFF)
1553 {
1554 if(data_type_ == FLOAT)
1555 offset += bytes/sizeof(float);
1556 else if(data_type_ == BF16)
1557 offset += bytes/sizeof(libxsmm_bfloat16);
1558 }
1559 #else
1560 if(tenType == "WEIGHT")
1561 offset += bytes/sizeof(float);
1562 #endif
1563 else if((tenType == "BIAS" && (buftype == DATA || buftype == DIFF)) || tenType == "STATS")
1564 offset += bytes/sizeof(float);
1565
1566 #if 0
1567 if(solver_->getGlobalFlag())
1568 {
1569 if(tenType == "WEIGHT" && buftype == DIFF)
1570 {
1571 if(data_type_ == FLOAT)
1572 {
1573 lrptr += bytes/sizeof(float);
1574 decptr += bytes/sizeof(float);
1575 }
1576 else if(data_type_ == BF16)
1577 {
1578 lrptr += bytes/sizeof(libxsmm_bfloat16);
1579 decptr += bytes/sizeof(libxsmm_bfloat16);
1580 }
1581 }
1582 else if(tenType == "BIAS" && buftype == DIFF)
1583 {
1584 lrptr += bytes/sizeof(float);
1585 decptr += bytes/sizeof(float);
1586 }
1587 }
1588 #endif
1589
1590 assert(ptr <= buf_ + s);
1591
1592 // For canary checking
1593 if(ttp)
1594 {
1595 memset(ptr, CANARY, END_GUARD_BAND);
1596 can_ptr.push_back(bytes);
1597 assert(can_ptr.size() <= num_canaries);
1598 }
1599 if(ttp)
1600 ptr += END_GUARD_BAND;
1601 }
1602 assert(ptr <= buf_ + s);
1603 #if 0
1604 printf("ptr @ %p\n",ptr);
1605 #endif
1606 }
1607
1608 if(tenType=="WEIGHT" && buftype==DATA)
1609 {
1610 #ifdef _OPENMP
1611 #pragma omp parallel
1612 #endif
1613 {
1614 int tid = omp_get_thread_num();
1615 int ntps = num_threads_/NUM_NUMA_NODES;
1616 int n = tid/ntps;
1617 int w = total_weights_;
1618 if(n != 0 && tid % ntps == 0)
1619 {
1620 float *wtptr = (float*)weight_buf_[n];
1621
1622 #pragma omp simd
1623 for(int i=0; i<w; i++)
1624 wtptr[i] = ((float*)weight_buf_[0])[i];
1625 }
1626 }
1627
1628 if(lp)
1629 convert_f32_bf16((float**)weight_buf_, (libxsmm_bfloat16**)lpweight_buf_, total_weights_);
1630 }
1631
1632 if(tenType=="WEIGHT" && buftype==DIFF)
1633 {
1634 if(data_type_ == FLOAT)
1635 {
1636 #ifdef _OPENMP
1637 #pragma omp parallel
1638 #endif
1639 {
1640 int tid = omp_get_thread_num();
1641 int ntps = num_threads_/NUM_NUMA_NODES;
1642 int n = tid/ntps;
1643 int w = total_weights_;
1644 if(n != 0 && tid % ntps == 0)
1645 {
1646 float *wdiff = (float*)wdiff_buf_[n];
1647
1648 #pragma omp simd
1649 for(int i=0; i<w; i++)
1650 wdiff[i] = ((float*)wdiff_buf_[0])[i];
1651 }
1652 }
1653 }
1654 else if(data_type_ == BF16)
1655 {
1656 #ifdef _OPENMP
1657 #pragma omp parallel
1658 #endif
1659 {
1660 int tid = omp_get_thread_num();
1661 int ntps = num_threads_/NUM_NUMA_NODES;
1662 int n = tid/ntps;
1663 int w = total_weights_;
1664 if(n != 0 && tid % ntps == 0)
1665 {
1666 libxsmm_bfloat16 *lpwdiff = (libxsmm_bfloat16*)lpwdiff_buf_[n];
1667 float *wdiff = (float*)wdiff_buf_[n];
1668
1669 #pragma omp simd
1670 for(int i=0; i<w; i++)
1671 {
1672 lpwdiff[i] = ((libxsmm_bfloat16*)lpwdiff_buf_[0])[i];
1673 wdiff[i] = ((float*)wdiff_buf_[0])[i];
1674 }
1675 }
1676 }
1677 }
1678
1679 #if 0
1680 #ifdef _OPENMP
1681 #pragma omp parallel
1682 #endif
1683 {
1684 int tid = omp_get_thread_num();
1685 int ntps = num_threads_/NUM_NUMA_NODES;
1686 int n = tid/ntps;
1687 int w = total_weights_;
1688 if(n != 0 && tid % ntps == 0)
1689 {
1690 float *lrp = (float*)wt_lr_mult_[n];
1691 float *dcp = (float*)wt_decay_mult_[n];
1692
1693 #pragma omp simd
1694 for(int i=0; i<w; i++)
1695 {
1696 lrp[i] = ((float*)wt_lr_mult_[0])[i];
1697 dcp[i] = ((float*)wt_decay_mult_[0])[i];
1698 }
1699 }
1700 }
1701 #endif
1702 }
1703
1704 if(tenType=="WEIGHT" && buftype==HISTORY)
1705 {
1706 #ifdef _OPENMP
1707 #pragma omp parallel
1708 #endif
1709 {
1710 int tid = omp_get_thread_num();
1711 int ntps = num_threads_/NUM_NUMA_NODES;
1712 int n = tid/ntps;
1713 int w = total_weights_;
1714 if(n != 0 && tid % ntps == 0)
1715 {
1716 float *inc = (float*)winc_buf_[n];
1717
1718 #pragma omp simd
1719 for(int i=0; i<w; i++)
1720 inc[i] = ((float*)winc_buf_[0])[i];
1721 }
1722 }
1723 }
1724
1725 #if 1
1726 if(tenType == "BIAS" && buftype == DATA)
1727 {
1728 #ifdef _OPENMP
1729 #pragma omp parallel
1730 #endif
1731 {
1732 int tid = omp_get_thread_num();
1733 int ntps = num_threads_/NUM_NUMA_NODES;
1734 int n = tid/ntps;
1735 int b = total_biases_;
1736
1737 if(n != 0 && tid % ntps == 0)
1738 {
1739 float *bias = (float*)bias_buf_[n];
1740
1741 #pragma omp simd
1742 for(int i=0; i<b; i++)
1743 bias[i] = ((float*)bias_buf_[0])[i];
1744 }
1745 }
1746 }
1747
1748
1749 if(tenType == "BIAS" && buftype == DIFF)
1750 {
1751 #ifdef _OPENMP
1752 #pragma omp parallel
1753 #endif
1754 {
1755 int tid = omp_get_thread_num();
1756 int ntps = num_threads_/NUM_NUMA_NODES;
1757 int n = tid/ntps;
1758 int b = total_biases_;
1759
1760 if(n != 0 && tid % ntps == 0)
1761 {
1762 float *bidiff = (float*)bidiff_buf_[n];
1763 #if 0
1764 float *lrp = (float*)bias_lr_mult_[n];
1765 float *dcp = (float*)bias_decay_mult_[n];
1766 #endif
1767
1768 #pragma omp simd
1769 for(int i=0; i<b; i++)
1770 {
1771 bidiff[i] = ((float*)bidiff_buf_[0])[i];
1772 #if 0
1773 lrp[i] = ((float*)bias_lr_mult_[0])[i];
1774 dcp[i] = ((float*)bias_decay_mult_[0])[i];
1775 #endif
1776 }
1777 }
1778 }
1779 }
1780
1781 if(tenType == "BIAS" && buftype == HISTORY)
1782 {
1783 #ifdef _OPENMP
1784 #pragma omp parallel
1785 #endif
1786 {
1787 int tid = omp_get_thread_num();
1788 int ntps = num_threads_/NUM_NUMA_NODES;
1789 int n = tid/ntps;
1790 int b = total_biases_;
1791
1792 if(n != 0 && tid % ntps == 0)
1793 {
1794 float *biinc = (float*)biinc_buf_[n];
1795
1796 #pragma omp simd
1797 for(int i=0; i<b; i++)
1798 biinc[i] = ((float*)biinc_buf_[0])[i];
1799 }
1800 }
1801 }
1802
1803 if(tenType == "STATS")
1804 {
1805 #ifdef _OPENMP
1806 #pragma omp parallel
1807 #endif
1808 {
1809 int tid = omp_get_thread_num();
1810 int ntps = num_threads_/NUM_NUMA_NODES;
1811 int n = tid/ntps;
1812 int b = total_biases_;
1813
1814 if(n != 0 && tid % ntps == 0)
1815 {
1816 float *stats = (float*)stats_buf_[n];
1817
1818 #pragma omp simd
1819 for(int i=0; i<b; i++)
1820 stats[i] = ((float*)stats_buf_[0])[i];
1821 }
1822 }
1823 }
1824 #endif
1825 }
1826
// Rewrites topology p into *ps, inserting explicit "Split" nodes wherever a
// single top (output tensor) of one node is consumed as the bottom (input) of
// more than one other node. Each consumer is then re-pointed at its own
// private split output, so every tensor in *ps has exactly one producer and
// one consumer edge per name.
void MLEngine::insertSplitNodes(NTGParameter& p, NTGParameter* ps)
{
  // Start from a full copy of the input topology, then rebuild the node list
  ps->CopyFrom(p);
  ps->clear_node();

  // All (top-name, producing-node-name) pairs in the original topology
  vector< pair<string, string> > top_names;

  for(int i=0; i<p.node_size(); i++)
  {
    const NodeParameter& np = p.node(i);
    string nn = np.name();
    for(int j=0; j<np.top_size(); j++)
      top_names.push_back(make_pair(np.top(j), nn));
  }

  // Map from a top name to every node (other than its producer) that
  // consumes it as a bottom. count(t) > 1 means t fans out and needs a Split.
  std::multimap<std::string, NodeParameter> top_as_bot;

  for(int i=0; i < top_names.size(); i++)
  {
    pair<string, string> tn = top_names[i];
    for(int j=0; j < p.node_size(); j++)
    {
      const NodeParameter& np = p.node(j);
      string nn = p.node(j).name();
      // Skip the producer itself; in-place tops list themselves as bottoms
      if(nn.compare(tn.second) == 0) continue;
      for(int k=0; k < np.bottom_size(); k++)
      {
        std::string t = tn.first;
        if(t.compare(p.node(j).bottom(k)) == 0)
          top_as_bot.insert(make_pair(t, p.node(j)));
      }
    }
  }

  // old_bottom: shared top name -> each consumer node's name
  // new_bottom: consumer node's name -> the split output name it should use
  std::multimap<std::string, std::string> old_bottom;
  std::multimap<std::string, std::string> new_bottom;

  for(int i=0; i<p.node_size(); i++)
  {
    // Copy the original node into the output topology unchanged
    NodeParameter* np = ps->add_node();
    np->CopyFrom(p.node(i));
    string onn = np->name();

    for(int j=0; j<np->top_size(); j++)
    {
      string t = np->top(j);
      int split_count = top_as_bot.count(t);
      if(split_count > 1)
      {
        // Top t has multiple consumers: synthesize a Split node that takes
        // t as its single bottom and produces one top per consumer.
        NodeParameter *snp = ps->add_node();
        snp->Clear();
        snp->add_bottom(t);
        string snn = t + "_" + onn + "_" + std::to_string(j) + "_split";
        snp->set_name(snn);
        snp->set_type("Split");
        // Labels carry no gradient, so don't back-propagate through them
        if(t.compare("label") == 0)
          snp->set_propagate_down(false);

        std::multimap<string, NodeParameter>::iterator it;
        int k = 0;
        // One split output per consumer of t, suffixed with its index k
        for(it=top_as_bot.equal_range(t).first; it != top_as_bot.equal_range(t).second; it++)
        {
          NodeParameter onp = (*it).second;
          string nn = onp.name();

          string stn = t + "_" + nn + "_" + std::to_string(j) + "_split_" + std::to_string(k);
          snp->add_top(stn);
          k++;

          // Record the rewiring (t consumed by nn) -> (nn should now read stn)
          for(int l=0; l<onp.bottom_size(); l++)
          {
            if(onp.bottom(l) == t)
            {
              old_bottom.insert(make_pair(t, nn));
              new_bottom.insert(make_pair(nn, stn));
            }
          }
        }
      }
    }
  }

  // Second pass over the rewritten topology: re-point each consumer's bottom
  // from the shared top name to its dedicated split output.
  std::multimap<std::string, std::string>::iterator it1;
  std::multimap<std::string, std::string>::iterator it2;
  for(int i=0; i<ps->node_size(); i++)
  {
    NodeParameter* mn = ps->mutable_node(i);
    // Split nodes keep the original top as their bottom; leave them alone
    if(mn->type().compare("Split") == 0) continue;
    for(int j=0; j<mn->bottom_size(); j++)
    {
      string t = mn->bottom(j);
      it1 = old_bottom.find(t);
      if(it1 == old_bottom.end()) continue;

      // Find the entry for this specific consumer among all consumers of t
      for(it1=old_bottom.equal_range(t).first; it1 != old_bottom.equal_range(t).second; it1++)
        if(mn->name() == (*it1).second) break;

      // NOTE(review): if the loop above exhausts equal_range without a match,
      // it1 equals equal_range(t).second, which is not necessarily end() —
      // the assert may not catch that case; verify consumers always match.
      assert(it1 != old_bottom.end());
      string s = (*it1).second;
      // Among this consumer's replacement names, pick the one derived from t
      for(it2=new_bottom.equal_range(s).first; it2 != new_bottom.equal_range(s).second; it2++)
      {
        string v = (*it2).second;
        if(v.find(mn->bottom(j)) != v.npos)
          mn->set_bottom(j, v);
      }
    }
  }
}
1935
create(int mode,string ntgConfig,string solverConfig)1936 void MLEngine::create(int mode, string ntgConfig, string solverConfig)
1937 {
1938 bool parsed = parseMLConfig(ntgConfig, &ntgparam_);
1939 if(!parsed) exit(-1);
1940
1941 if(!solverConfig.empty())
1942 {
1943 parsed = parseSolverConfig(solverConfig, &sparam_);
1944 if(!parsed) exit(-1);
1945
1946 num_epochs_ = sparam_.max_epochs();
1947 current_epoch_ = 0;
1948 current_batch_ = 0;
1949 load_from_checkpoint_ = sparam_.load_checkpoint();
1950 checkpoint_dir_ = sparam_.checkpoint_dir();
1951 checkpoint_format_ = sparam_.checkpoint_format();
1952 data_type_ = sparam_.data_type();
1953 }
1954
1955 #ifdef _OPENMP
1956 num_threads_ = omp_get_max_threads();
1957 #else
1958 num_threads_ = 1;
1959 #endif
1960
1961 printf("Using %d threads\n",num_threads_);
1962
1963 #ifdef USE_MLSL
1964 global_node_id_ = MLSL::Environment::GetEnv().GetProcessIdx();
1965 num_machines_ = MLSL::Environment::GetEnv().GetProcessCount();
1966 data_parallelism = NULL;
1967 if(mode == TRAIN || mode == VAL)
1968 session_ = MLSL::Environment::GetEnv().CreateSession(MLSL::PT_TRAIN);
1969 else
1970 session_ = MLSL::Environment::GetEnv().CreateSession(MLSL::PT_TEST);
1971 #else
1972 global_node_id_ = 0;
1973 num_machines_ = 1;
1974 #endif
1975
1976 // if no training mode in config, then set inferenceOnly_ to true
1977 inferenceOnly_ = (mode == TEST);
1978
1979 // Initialize solver node
1980 int ni = find_in_nodeTypeList("Solver");
1981 solverParams_ = parseSolverParams(&sparam_);
1982 solver_ = new SolverNode(solverParams_, this);
1983
1984 /*************************************************************************************/
1985 /*** Create a global tensor to hold scratch memory needed by Conv layers (LIBXSMM) ***/
1986 /*************************************************************************************/
1987 tenScratch_ = new Tensor("scratch");
1988 tenScratchBuf_ = tenScratch_->getBuf(DATA);
1989 tenScratchBuf_->setBufferPtr(scratch);
1990
1991 NTGParameter split_ntgparam;
1992
1993 insertSplitNodes(ntgparam_, &split_ntgparam);
1994 if(global_node_id_ == 0)
1995 split_ntgparam.PrintDebugString();
1996
1997 int numNodes = split_ntgparam.node_size();
1998
1999 for(int i=0; i<numNodes; i++)
2000 {
2001 // get name and type of each node
2002 // call parse and create node functions based on type
2003 // find member of TypeList
2004 NodeParameter np = split_ntgparam.node(i);
2005 string ntype = np.type();
2006
2007 #ifdef DEBUG
2008 printf("node type %s\n",ntype.c_str());
2009 #endif
2010 int j = find_in_nodeTypeList(ntype);
2011
2012 MLParams *p = nodeTypes[j].parse(&np);
2013 MLNode *node = nodeTypes[j].create(p, this);
2014 ntg_.push_back(node);
2015 #ifdef USE_MLSL
2016 if(ntype.find("Data") != ntype.npos)
2017 data_parallelism = MLSL::Environment::GetEnv().CreateDistribution(num_machines_, 1);
2018 #endif
2019
2020 }
2021
2022 // We assert that the first node in the topology be a data node. Graph creation starts from data node
2023 NNNode* dnode = dynamic_cast<NNNode*>(ntg_[0]);
2024 assert(dnode != NULL);
2025
2026 string first = dnode->getNodeType();
2027 #ifdef DEBUG
2028 printf("first node type %s\n",first.c_str());
2029 #endif
2030 assert(first.find("Data") != first.npos);
2031
2032 // Create the neural network graph for training or testing mode
2033 dnode->createNNGraph(mode);
2034
2035 // Forward Pass Binning.
2036 // Look for tasks attached to nodes with no successors. Add them to the Execution Task Graph (etg) first.
2037 for(int i=numNodes-1; i>0; i--)
2038 {
2039 NNNode *nn = dynamic_cast<NNNode*>(ntg_[i]);
2040 Task* t = nn->getBasicTask(BASIC_TASK_FORW);
2041
2042 if(nn->getNumNextNodes() == 0)
2043 {
2044 etg_[mode].push_back(t);
2045 #ifndef NDEBUG
2046 printf("FP task %p (node %s), bin %d pushed to etg_\n",t, nn->getNodeName().c_str(), t->getMaxBin());
2047 #endif
2048 }
2049 }
2050
2051 // Assign bins to tasks based on their dependencies. Tasks with lower bin number must
2052 // execute before those with higher bin number. Tasks with same bin number can execute in parallel
2053 // Ensure no duplicate tasks in etg
2054 create_schedule(mode);
2055 optimize_schedule(mode);
2056
2057 if(mode == TRAIN)
2058 {
2059 for(auto it = etg_[mode].begin(); it != etg_[mode].end(); it++)
2060 {
2061 Task *t = *it;
2062 if(t->getBasicTaskId() == BASIC_TASK_FORW)
2063 etg_[VAL].push_back(t);
2064 else
2065 break;
2066 }
2067 }
2068
2069 #ifdef DEBUG
2070 for(auto it=etg_[mode].begin(); it != etg_[mode].end(); it++)
2071 {
2072 Task* t = (*it);
2073 string s = dynamic_cast<NNNode*>(t->getNode())->getNodeName();
2074 if(t->getBasicTaskId() == BASIC_TASK_FORW)
2075 printf("FP Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
2076 else if(t->getBasicTaskId() == BASIC_TASK_BACK)
2077 printf("BP Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
2078 else if(t->getBasicTaskId() == BASIC_TASK_WGRAD)
2079 printf("WG Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
2080 else
2081 printf("SOLVER Task %p in node %s at bin %d\n",t, s.c_str(), t->getMaxBin());
2082 }
2083 #endif
2084
2085 if(mode == TRAIN)
2086 printf("Training schedule has %u tasks\n",(unsigned int)etg_[mode].size());
2087 else
2088 printf("Testing schedule has %u tasks\n",(unsigned int)etg_[mode].size());
2089
2090
2091 /*** Allocate memory and set pointers for INPUT and LABEL buffers ***/
2092 /**********************************************************************/
2093 long long int total_input_size;
2094 long long int max_fwd_buffer_size=0;
2095
2096 allocate_memory("INPUT", inTList_, DATA, input_can_ptr, &ic, &total_input_size);
2097 if(global_node_id_ == 0)
2098 printf("Total input memory allocated %lld bytes\n", total_input_size);
2099
2100 /**********************************************************************/
2101 /*** Allocate memory and set pointers for FORWARD ACTIVATION buffer ***/
2102 /**********************************************************************/
2103 long long int total_fact_size;
2104 allocate_memory("FACT", outTList_, DATA, fact_can_ptr, &fac, &total_fact_size);
2105 if(global_node_id_ == 0)
2106 printf("Total forward activation memory allocated %lld bytes\n", total_fact_size);
2107
2108 /***********************************************************/
2109 /*** Allocate memory and set pointers for WEIGHTS buffer ***/
2110 /***********************************************************/
2111 long long int total_weight_size;
2112 allocate_memory("WEIGHT", wTList_, DATA, wt_can_ptr, &wtc, &total_weight_size);
2113 if(global_node_id_ == 0)
2114 printf("Total weights memory allocated %lld bytes\n", total_weight_size);
2115
2116 /***********************************************************/
2117 /*** Allocate memory and set pointers for BIASES buffer ***/
2118 /***********************************************************/
2119 long long int total_bias_size;
2120 allocate_memory("BIAS", biasTList_, DATA, bias_can_ptr, &bic, &total_bias_size);
2121 if(global_node_id_ == 0)
2122 printf("Total bias memory allocated %lld bytes\n", total_bias_size);
2123
2124 /***********************************************************/
2125 /*** Allocate memory and set pointers for STATS buffer ***/
2126 /***********************************************************/
2127 long long int total_stats_size;
2128 allocate_memory("STATS", statsTList_, DATA, stats_can_ptr, &sic, &total_stats_size);
2129 if(global_node_id_ == 0)
2130 printf("Total stats memory allocated %lld bytes\n", total_stats_size);
2131
2132 // Required only for training
2133 long long int total_bp_size;
2134 if(!inferenceOnly_)
2135 {
2136 /***********************************************************************/
2137 /*** Allocate memory and set pointers for BACKWARD ACTIVATION buffer ***/
2138 /***********************************************************************/
2139 #if !defined(USE_OPTBP_ALLOC)
2140 long long int total_bact_size;
2141 allocate_memory("BACT", outTList_, DIFF, bact_can_ptr, &bac, &total_bact_size);
2142 if(global_node_id_ == 0)
2143 printf("Total backward activation memory allocated %lld bytes\n", total_bact_size);
2144 #else
2145 long long int total_bact_size = NDIFFS * max_fwd_buffer_size;
2146 allocate_gradient_tensor(outTList_, DIFF, NDIFFS, max_fwd_buffer_size);
2147 if(global_node_id_ == 0)
2148 printf("Total backward activation memory allocated %lld bytes\n", total_bact_size);
2149 #endif
2150
2151 /********************************************************************/
2152 /*** Allocate memory and set pointers for WEIGHT GRADIENTS buffer ***/
2153 /********************************************************************/
2154 long long int total_wdiff_size;
2155 allocate_memory("WEIGHT", wTList_, DIFF, wdiff_can_ptr, &wdc, &total_wdiff_size);
2156 if(global_node_id_ == 0)
2157 printf("Total weight gradient memory allocated %lld bytes\n", total_wdiff_size);
2158
2159 /*********************************************************************/
2160 /*** Allocate memory and set pointers for WEIGHT INCREMENTS buffer ***/
2161 /*********************************************************************/
2162 long long int total_winc_size;
2163 allocate_memory("WEIGHT", wTList_, HISTORY, winc_can_ptr, &wic, &total_winc_size);
2164 if(global_node_id_ == 0)
2165 printf("Total weight increment memory allocated %lld bytes\n", total_winc_size);
2166
2167 /********************************************************************/
2168 /*** Allocate memory and set pointers for BIAS GRADIENTS buffer ***/
2169 /********************************************************************/
2170 long long int total_bidiff_size;
2171 allocate_memory("BIAS", biasTList_, DIFF, bidiff_can_ptr, &bidc, &total_bidiff_size);
2172 if(global_node_id_ == 0)
2173 printf("Total bias gradient memory allocated %lld bytes\n", total_bidiff_size);
2174
2175 /*********************************************************************/
2176 /*** Allocate memory and set pointers for BIAS INCREMENTS buffer ***/
2177 /*********************************************************************/
2178 long long int total_biinc_size;
2179 allocate_memory("BIAS", biasTList_, HISTORY, biinc_can_ptr, &biic, &total_biinc_size);
2180 if(global_node_id_ == 0)
2181 printf("Total bias increment memory allocated %lld bytes\n", total_biinc_size);
2182
2183 total_bp_size = total_bact_size + total_wdiff_size + total_winc_size + total_bidiff_size + total_biinc_size;
2184 }
2185
2186 long long int total_memory = total_input_size + total_fact_size + total_weight_size + total_bias_size + total_bp_size;
2187 if(global_node_id_ == 0)
2188 printf("Total tensor memory = %lld\n",total_memory);
2189 }
2190