1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #pragma once
14 #include <map>
15 #include <list>
16 #include <vector>
17 #include <algorithm>
18 #include <set>
19 #include <omp.h>
20 #include <sys/time.h>
21 #include <stdlib.h>
22 #include "proto/gxm.pb.h"
23 #include "Engine.fwd.hpp"
24 #include "MLNode.fwd.hpp"
25 #include "Config.hpp"
26 #include "Task.hpp"
27 #include "common.hpp"
28 #include "Solver.hpp"
29 #include "libxsmm.h"
30 #ifdef USE_MLSL
31 #include "mpi.h"
32 #endif
33 
// NOTE: these using-directives are at namespace scope in a header (#pragma once above),
// so they also take effect in every file that includes this header from this point on.
// The declarations below rely on them for unqualified names such as vector, list, map and string.
34 using namespace std;
35 using namespace gxm;
36 
// Declaration only: the int object `iter` is defined elsewhere.
// NOTE(review): presumably a global iteration counter -- confirm at its definition.
37 extern int iter;
38 
39 #ifdef USE_MLSL
40 #include "mlsl.hpp"
41 //using namespace MLSL;
42 #endif
43 
/* Mode identifiers. Their values match the 0/1/2 indexing noted on
 * MLEngine::etg_ (0 - training, 1 - validation, 2 - testing). */
44 #define TRAIN 0
45 #define VAL 1
46 #define TEST 2
/* NOTE(review): presumably the sizes of guard regions placed before and after
 * allocated buffers; units and users are not visible in this header -- confirm
 * in the implementation file. */
47 #define START_GUARD_BAND 64
48 #define END_GUARD_BAND 64
/* NOTE(review): presumably the fill value verified by MLEngine::canary_check --
 * confirm in the implementation file. */
49 #define CANARY 0x7F
/* NOTE(review): meaning not visible in this header -- confirm at the use sites. */
50 #define NDIFFS 10
51 
52 struct dupChecker_ {
dupChecker_dupChecker_53   inline dupChecker_() : tmpSet() {}
operator ()dupChecker_54   inline bool operator()(Task *t) {
55     return tmpSet.insert(t).second;
56   }
57   private:
58     std::set<Task *> tmpSet;
59 };
60 
61 class MLEngine
62 {
63   protected:
64     NTGParameter ntgparam_;
65     NodeParameter np_;
66     SolverParameter sparam_;
67 #ifdef USE_MLSL
68     MLSL::Distribution *data_parallelism;
69     MLSL::Session *session_;
70 #endif
71     vector<MLNode*> ntg_;
72     list<Task*> etg_[3]; // 0 - Training, 1 - Validation, 2 - testing
73     SolverParams *solverParams_;
74     SolverNode* solver_;
75     Tensor* tenScratch_;
76     TensorBuf* tenScratchBuf_;
77 
78     struct TensorPair
79     {
80       string name;
81       Tensor* t;
82     };
83     typedef list<TensorPair> TensorList;
84     typedef TensorList::iterator Iter;
85     typedef map<string, Iter> Tmap;
86 
87     Tmap inTensorMap_, outTensorMap_, weightTensorMap_, biasTensorMap_, statsTensorMap_;
88     TensorList defTList_, inTList_, outTList_, wTList_, biasTList_, statsTList_;
89 
90     bool inferenceOnly_, load_from_checkpoint_;
91     string checkpoint_dir_, checkpoint_format_;
92     int num_epochs_, exec_mode_, current_epoch_, current_batch_;
93     int data_type_;
94     int num_machines_, num_machine_groups_, num_threads_;
95     int batch_size_, num_train_batches_, num_test_batches_, num_test_views_;
96     int global_node_id_;
97     float lr_, *wt_lr_mult_[NUM_NUMA_NODES], *wt_decay_mult_[NUM_NUMA_NODES];
98     float *bias_lr_mult_[NUM_NUMA_NODES], *bias_decay_mult_[NUM_NUMA_NODES];
99     float scf_=0;
100 
101     void *input_buf_=NULL;
102     void *fact_buf_=NULL, *bact_buf_=NULL, *wbuf_=NULL;
103     void *weight_buf_[NUM_NUMA_NODES]={NULL}, *wdiff_buf_[NUM_NUMA_NODES]={NULL};
104     void *winc_buf_[NUM_NUMA_NODES]={NULL}, *lpweight_buf_[NUM_NUMA_NODES]={NULL};
105     void *lpwdiff_buf_[NUM_NUMA_NODES]={NULL};
106 #if 1
107     void *bias_buf_[NUM_NUMA_NODES]={NULL}, *bidiff_buf_[NUM_NUMA_NODES]={NULL};
108     void *biinc_buf_[NUM_NUMA_NODES]={NULL}, *stats_buf_[NUM_NUMA_NODES]={NULL};
109 #else
110     void *bias_buf_=NULL, *bidiff_buf_=NULL;
111     void *biinc_buf_=NULL, *stats_buf_=NULL;
112 #endif
113     int total_weights_, total_biases_, orig_total_weights_;
114     void *scratch[NUM_NUMA_NODES]={NULL};
115 
116     vector<int> input_can_ptr;
117     vector<int> fact_can_ptr, bact_can_ptr;
118     vector<int> wt_can_ptr, wdiff_can_ptr, winc_can_ptr;
119     vector<int> bias_can_ptr, stats_can_ptr, bidiff_can_ptr, biinc_can_ptr;
120 #ifdef USE_MLSL
121     vector<MLSL::Operation*> wtgrad_comms_vec, bias_grad_comms_vec, combo_grad_comms_vec;
122 #endif
123     int ic, fac, bac, wtc, wdc, wic, bic, sic, bidc, biic;
124 
125     void create_schedule(int);
126     void optimize_schedule(int);
127     void allocate_tensor_memory(Tensor*, int, void*);
128     void clear_history(TensorList);
129     int find_in_nodeTypeList(string);
130     void checkpoint(TensorList L, int);
131     void read_checkpoint_file(TensorBuf*, string, string);
132     void load_checkpoint(TensorList, int, string);
133     void canary_check(void*, vector<int>&, int);
134     void allocate_memory(string, TensorList, int, vector<int>&, int*, long long int*);
135     void* allocate_gradient_tensor(TensorList, int, int, long long int);
136     void insertSplitNodes(NTGParameter& p, NTGParameter* ps);
137     void convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len, int numa_node);
138     void convert_f32_bf16(float** in, libxsmm_bfloat16** out, int len);
139     void convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len);
140     void waitForComms(string);
141 
142   public:
MLEngine()143     MLEngine() {}
~MLEngine()144     virtual ~MLEngine() {}
145 
146     void create(int mode, string ntgConfig, string solverConfig);
147     bool register_tensor(string name, int type, Tensor* t);
148     Tensor* get_tensor(string name, int type);
149     void execute_on_thread(int num_threads, MLNode* node, void (*fname)(int tid));
150     void set_global_strategy(MachineParameter* mparam);
151     void run(int mode);
152 
getSolver()153     SolverNode* getSolver() { return solver_; }
getScratchBuffer()154     TensorBuf* getScratchBuffer() { return tenScratchBuf_; }
155 
is_inference_only()156     bool is_inference_only() { return inferenceOnly_; }
157 
get_num_threads()158     int get_num_threads() { return num_threads_; }
get_num_machines()159     int get_num_machines() { return num_machines_; }
get_num_machine_groups()160     int get_num_machine_groups() { return num_machine_groups_; }
get_num_epochs()161     int get_num_epochs() { return num_epochs_;}
get_current_epoch()162     int get_current_epoch() { return current_epoch_; }
get_current_batch()163     int get_current_batch() { return current_batch_; }
get_execution_mode()164     int get_execution_mode() { return exec_mode_; }
get_global_node_id()165     int get_global_node_id() { return global_node_id_; }
get_num_train_batches()166     int get_num_train_batches() { return num_train_batches_; }
get_num_test_batches()167     int get_num_test_batches() { return num_test_batches_; }
get_num_test_views()168     int get_num_test_views() {return num_test_views_; }
get_batch_size()169     int get_batch_size() { return batch_size_; }
get_scaling_factor()170     float get_scaling_factor() { return scf_; }
171 #ifdef USE_MLSL
get_wtgrad_comms_vec()172     vector<MLSL::Operation*>& get_wtgrad_comms_vec() { return wtgrad_comms_vec; }
get_bias_grad_comms_vec()173     vector<MLSL::Operation*>& get_bias_grad_comms_vec() { return bias_grad_comms_vec; }
get_combo_grad_comms_vec()174     vector<MLSL::Operation*>& get_combo_grad_comms_vec() { return combo_grad_comms_vec; }
175 #endif
176 
set_batch_size(int b)177     void set_batch_size(int b) {batch_size_ = b; }
set_num_train_batches(int ntrainb)178     void set_num_train_batches(int ntrainb) {num_train_batches_ = ntrainb; }
set_num_test_batches(int ntestb)179     void set_num_test_batches(int ntestb) {num_test_batches_ = ntestb; }
set_num_test_views(int ntestv)180     void set_num_test_views(int ntestv) {num_test_views_ = ntestv; }
set_learning_rate(float lr)181     void set_learning_rate(float lr) { lr_ = lr; }
set_scaling_factor(float scf)182     void set_scaling_factor(float scf) { scf_ = scf; }
183 #ifdef USE_MLSL
get_distribution()184     MLSL::Distribution* get_distribution() { return data_parallelism; }
get_session()185     MLSL::Session *get_session() { return session_; }
186 #endif
187 
188 };
189 
190