1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.) 10 ******************************************************************************/ 11 12 13 #pragma once 14 #include <string> 15 #include <vector> 16 #include <algorithm> 17 #include <list> 18 #include <algorithm> 19 #include "Params.hpp" 20 #include "MLNode.hpp" 21 #include "Engine.fwd.hpp" 22 #include "Task.hpp" 23 #include "proto/gxm.pb.h" 24 25 using namespace std; 26 using namespace gxm; 27 #ifdef USE_MLSL 28 #include "mlsl.hpp" 29 #endif 30 31 class NNParams : public MLParams 32 { 33 protected: 34 vector<string> top_; 35 vector<string> bottom_; 36 string nname_; 37 string type_; 38 int mode_; 39 bool bp_flag_; 40 41 public: NNParams(void)42 NNParams(void) {} ~NNParams(void)43 virtual ~NNParams(void) {} 44 set_top_names(string name)45 void set_top_names(string name) { top_.push_back(name); } set_bottom_names(string name)46 void set_bottom_names(string name) { bottom_.push_back(name); } set_node_name(string nname)47 void set_node_name(string nname) { nname_ = nname; } set_node_type(string type)48 void set_node_type(string type) {type_ = type; } set_mode(int mode)49 void set_mode(int mode) { mode_ = mode; } set_bprop_flag(bool flag)50 void set_bprop_flag(bool flag) { bp_flag_ = flag; } 51 get_node_name()52 string get_node_name() { return nname_; } get_top_names()53 vector<string>& get_top_names() { return top_; } get_bottom_names()54 vector<string>& get_bottom_names() { return bottom_; } get_node_type()55 string get_node_type() { return type_; } get_mode()56 int get_mode() { return mode_; } get_bprop_flag()57 bool get_bprop_flag() { return bp_flag_; } 58 }; 59 60 class NNNode : public MLNode 61 { 62 public: NNNode(NNParams * p,MLEngine * e)63 NNNode(NNParams* p, MLEngine* e) : MLNode(p, e) 64 { 65 for(int i = 0; i < 4; i++) tBasic_[i] = NULL; 66 } 67 ~NNNode(void)68 virtual ~NNNode(void) 69 { 70 for(int i = 0; i < 4; i++) if(tBasic_[i] != NULL) { delete tBasic_[i]; tBasic_[i] = NULL; } 71 } 72 createTasks(list<Task * >,int)73 void createTasks(list<Task*>, int) {} createStrategy(int)74 virtual void createStrategy(int) {} 75 forwardPropagate()76 virtual void forwardPropagate() {} backPropagate()77 virtual void backPropagate() {} weightUpdate()78 virtual void weightUpdate() {} solverStep()79 virtual void solverStep() {} 80 executeTask(int taskId)81 int executeTask(int taskId) 82 { 83 if(taskId == 0) 84 { 85 forwardPropagate(); 86 } 87 else if(taskId == 1) 88 { 89 backPropagate(); 90 } 91 else if(taskId == 2) 92 { 93 weightUpdate(); 94 } 95 else if(taskId == 3) 96 { 97 solverStep(); 98 } 99 return 0; 100 } 101 enqueTask(int pos)102 void enqueTask(int pos) {} 103 createPersistentTask()104 virtual void createPersistentTask() {} 105 setNextNode(NNNode * next)106 void setNextNode(NNNode* next) 107 { 108 //check if next is already in the nextNodes list 109 if(std::find(nextNodes_.begin(), nextNodes_.end(), next) == nextNodes_.end()) 110 { 111 nextNodes_.push_back(next); 112 next->prevNodes_.push_back(this); 113 } 114 } 115 setPrevNode(NNNode * prev)116 void setPrevNode(NNNode* prev) 117 { 118 //check if prev is already in the prevNodes list 119 if(std::find(prevNodes_.begin(), prevNodes_.end(), prev) == prevNodes_.end()) 120 { 121 prevNodes_.push_back(prev); 122 prev->nextNodes_.push_back(this); 123 } 124 } 125 getBasicTask(int type)126 Task *getBasicTask(int type) 127 { 128 int index = -1; 129 if(type == 0 || (type == 1 && bp_flag_) || (type > 1 && has_weights_)) 130 index = type; 131 if(index != -1) { 132 if(tBasic_[index] == NULL) tBasic_[index] = new Task(this, -1, type); 133 return tBasic_[index]; 134 } 135 return NULL; 136 } 137 138 void createNNGraph(int mode); 139 setNodeType(string type)140 void setNodeType(string type) { ntype_ = type; } 141 getNodeType()142 string getNodeType() { return ntype_; } getNodeName()143 string getNodeName() { return nname_; } getMode()144 int getMode() { return mode_; } 145 getNumPrevNodes()146 int getNumPrevNodes() { return prevNodes_.size(); } getNumNextNodes()147 int getNumNextNodes() { return nextNodes_.size(); } 148 getPrevNode(int i)149 NNNode* getPrevNode(int i) { if(prevNodes_.size() > 0) return prevNodes_[i]; else return NULL; } getNextNode(int i)150 NNNode* getNextNode(int i) { if(nextNodes_.size() > 0) return nextNodes_[i]; else return NULL; } 151 get_num_tops()152 int get_num_tops() { return top_.size(); } set_top_compute_engine(int e)153 void set_top_compute_engine(int e) { top_compute_engine_ = e; } get_bot_compute_engine()154 int get_bot_compute_engine() { return bot_compute_engine_; } set_next_node_type(string s)155 void set_next_node_type(string s) {next_ntype_ = s;} 156 refineTask()157 void refineTask(){} 158 createCheckPoint()159 virtual void createCheckPoint() {} restoreCheckPoint()160 virtual void restoreCheckPoint() {} 161 162 protected: 163 string nname_, ntype_, next_ntype_; 164 vector<string> top_; 165 vector<string> bottom_; 166 int mode_; 167 bool bp_flag_; 168 bool has_weights_; 169 vector<NNNode*> prevNodes_; 170 vector<NNNode*> nextNodes_; 171 int top_compute_engine_, bot_compute_engine_; 172 #ifdef USE_MLSL 173 MLSL::Operation* op_; 174 #endif 175 176 177 // 0-Forw, 1-Back, 2-WGrad, 3-Solver 178 Task *tBasic_[4]; 179 }; 180 181