1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.) 10 ******************************************************************************/ 11 12 13 #pragma once 14 #include <map> 15 #include <list> 16 #include <vector> 17 #include <algorithm> 18 #include <set> 19 #include <omp.h> 20 #include <sys/time.h> 21 #include <stdlib.h> 22 #include "proto/gxm.pb.h" 23 #include "Engine.fwd.hpp" 24 #include "MLNode.fwd.hpp" 25 #include "Config.hpp" 26 #include "Task.hpp" 27 #include "common.hpp" 28 #include "Solver.hpp" 29 #include "libxsmm.h" 30 #ifdef USE_MLSL 31 #include "mpi.h" 32 #endif 33 34 using namespace std; 35 using namespace gxm; 36 37 extern int iter; 38 39 #ifdef USE_MLSL 40 #include "mlsl.hpp" 41 //using namespace MLSL; 42 #endif 43 44 #define TRAIN 0 45 #define VAL 1 46 #define TEST 2 47 #define START_GUARD_BAND 64 48 #define END_GUARD_BAND 64 49 #define CANARY 0x7F 50 #define NDIFFS 10 51 52 struct dupChecker_ { dupChecker_dupChecker_53 inline dupChecker_() : tmpSet() {} operator ()dupChecker_54 inline bool operator()(Task *t) { 55 return tmpSet.insert(t).second; 56 } 57 private: 58 std::set<Task *> tmpSet; 59 }; 60 61 class MLEngine 62 { 63 protected: 64 NTGParameter ntgparam_; 65 NodeParameter np_; 66 SolverParameter sparam_; 67 #ifdef USE_MLSL 68 MLSL::Distribution *data_parallelism; 69 MLSL::Session *session_; 70 #endif 71 vector<MLNode*> ntg_; 72 list<Task*> etg_[3]; // 0 - Training, 1 - Validation, 2 - testing 73 SolverParams *solverParams_; 74 SolverNode* solver_; 75 Tensor* tenScratch_; 76 TensorBuf* tenScratchBuf_; 77 78 struct TensorPair 79 { 80 string name; 81 Tensor* t; 82 }; 83 typedef list<TensorPair> TensorList; 84 typedef TensorList::iterator Iter; 85 typedef map<string, Iter> Tmap; 86 87 Tmap inTensorMap_, outTensorMap_, weightTensorMap_, biasTensorMap_, statsTensorMap_; 88 TensorList defTList_, inTList_, outTList_, wTList_, biasTList_, statsTList_; 89 90 bool inferenceOnly_, load_from_checkpoint_; 91 string checkpoint_dir_, checkpoint_format_; 92 int num_epochs_, exec_mode_, current_epoch_, current_batch_; 93 int data_type_; 94 int num_machines_, num_machine_groups_, num_threads_; 95 int batch_size_, num_train_batches_, num_test_batches_, num_test_views_; 96 int global_node_id_; 97 float lr_, *wt_lr_mult_[NUM_NUMA_NODES], *wt_decay_mult_[NUM_NUMA_NODES]; 98 float *bias_lr_mult_[NUM_NUMA_NODES], *bias_decay_mult_[NUM_NUMA_NODES]; 99 float scf_=0; 100 101 void *input_buf_=NULL; 102 void *fact_buf_=NULL, *bact_buf_=NULL, *wbuf_=NULL; 103 void *weight_buf_[NUM_NUMA_NODES]={NULL}, *wdiff_buf_[NUM_NUMA_NODES]={NULL}; 104 void *winc_buf_[NUM_NUMA_NODES]={NULL}, *lpweight_buf_[NUM_NUMA_NODES]={NULL}; 105 void *lpwdiff_buf_[NUM_NUMA_NODES]={NULL}; 106 #if 1 107 void *bias_buf_[NUM_NUMA_NODES]={NULL}, *bidiff_buf_[NUM_NUMA_NODES]={NULL}; 108 void *biinc_buf_[NUM_NUMA_NODES]={NULL}, *stats_buf_[NUM_NUMA_NODES]={NULL}; 109 #else 110 void *bias_buf_=NULL, *bidiff_buf_=NULL; 111 void *biinc_buf_=NULL, *stats_buf_=NULL; 112 #endif 113 int total_weights_, total_biases_, orig_total_weights_; 114 void *scratch[NUM_NUMA_NODES]={NULL}; 115 116 vector<int> input_can_ptr; 117 vector<int> fact_can_ptr, bact_can_ptr; 118 vector<int> wt_can_ptr, wdiff_can_ptr, winc_can_ptr; 119 vector<int> bias_can_ptr, stats_can_ptr, bidiff_can_ptr, biinc_can_ptr; 120 #ifdef USE_MLSL 121 vector<MLSL::Operation*> wtgrad_comms_vec, bias_grad_comms_vec, combo_grad_comms_vec; 122 #endif 123 int ic, fac, bac, wtc, wdc, wic, bic, sic, bidc, biic; 124 125 void create_schedule(int); 126 void optimize_schedule(int); 127 void allocate_tensor_memory(Tensor*, int, void*); 128 void clear_history(TensorList); 129 int find_in_nodeTypeList(string); 130 void checkpoint(TensorList L, int); 131 void read_checkpoint_file(TensorBuf*, string, string); 132 void load_checkpoint(TensorList, int, string); 133 void canary_check(void*, vector<int>&, int); 134 void allocate_memory(string, TensorList, int, vector<int>&, int*, long long int*); 135 void* allocate_gradient_tensor(TensorList, int, int, long long int); 136 void insertSplitNodes(NTGParameter& p, NTGParameter* ps); 137 void convert_f32_bf16(float* in, libxsmm_bfloat16* out, int len, int numa_node); 138 void convert_f32_bf16(float** in, libxsmm_bfloat16** out, int len); 139 void convert_bf16_f32(libxsmm_bfloat16* in, float* out, int len); 140 void waitForComms(string); 141 142 public: MLEngine()143 MLEngine() {} ~MLEngine()144 virtual ~MLEngine() {} 145 146 void create(int mode, string ntgConfig, string solverConfig); 147 bool register_tensor(string name, int type, Tensor* t); 148 Tensor* get_tensor(string name, int type); 149 void execute_on_thread(int num_threads, MLNode* node, void (*fname)(int tid)); 150 void set_global_strategy(MachineParameter* mparam); 151 void run(int mode); 152 getSolver()153 SolverNode* getSolver() { return solver_; } getScratchBuffer()154 TensorBuf* getScratchBuffer() { return tenScratchBuf_; } 155 is_inference_only()156 bool is_inference_only() { return inferenceOnly_; } 157 get_num_threads()158 int get_num_threads() { return num_threads_; } get_num_machines()159 int get_num_machines() { return num_machines_; } get_num_machine_groups()160 int get_num_machine_groups() { return num_machine_groups_; } get_num_epochs()161 int get_num_epochs() { return num_epochs_;} get_current_epoch()162 int get_current_epoch() { return current_epoch_; } get_current_batch()163 int get_current_batch() { return current_batch_; } get_execution_mode()164 int get_execution_mode() { return exec_mode_; } get_global_node_id()165 int get_global_node_id() { return global_node_id_; } get_num_train_batches()166 int get_num_train_batches() { return num_train_batches_; } get_num_test_batches()167 int get_num_test_batches() { return num_test_batches_; } get_num_test_views()168 int get_num_test_views() {return num_test_views_; } get_batch_size()169 int get_batch_size() { return batch_size_; } get_scaling_factor()170 float get_scaling_factor() { return scf_; } 171 #ifdef USE_MLSL get_wtgrad_comms_vec()172 vector<MLSL::Operation*>& get_wtgrad_comms_vec() { return wtgrad_comms_vec; } get_bias_grad_comms_vec()173 vector<MLSL::Operation*>& get_bias_grad_comms_vec() { return bias_grad_comms_vec; } get_combo_grad_comms_vec()174 vector<MLSL::Operation*>& get_combo_grad_comms_vec() { return combo_grad_comms_vec; } 175 #endif 176 set_batch_size(int b)177 void set_batch_size(int b) {batch_size_ = b; } set_num_train_batches(int ntrainb)178 void set_num_train_batches(int ntrainb) {num_train_batches_ = ntrainb; } set_num_test_batches(int ntestb)179 void set_num_test_batches(int ntestb) {num_test_batches_ = ntestb; } set_num_test_views(int ntestv)180 void set_num_test_views(int ntestv) {num_test_views_ = ntestv; } set_learning_rate(float lr)181 void set_learning_rate(float lr) { lr_ = lr; } set_scaling_factor(float scf)182 void set_scaling_factor(float scf) { scf_ = scf; } 183 #ifdef USE_MLSL get_distribution()184 MLSL::Distribution* get_distribution() { return data_parallelism; } get_session()185 MLSL::Session *get_session() { return session_; } 186 #endif 187 188 }; 189 190