1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #pragma once
14 #include <string>
15 #include <vector>
16 #include <algorithm>
17 #include <list>
18 #include <algorithm>
19 #include "Params.hpp"
20 #include "MLNode.hpp"
21 #include "Engine.fwd.hpp"
22 #include "Task.hpp"
23 #include "proto/gxm.pb.h"
24 
25 using namespace std;
26 using namespace gxm;
27 #ifdef USE_MLSL
28 #include "mlsl.hpp"
29 #endif
30 
31 class NNParams : public MLParams
32 {
33   protected:
34     vector<string> top_;
35     vector<string> bottom_;
36     string nname_;
37     string type_;
38     int mode_;
39     bool bp_flag_;
40 
41   public:
NNParams(void)42     NNParams(void) {}
~NNParams(void)43     virtual ~NNParams(void) {}
44 
set_top_names(string name)45     void set_top_names(string name) { top_.push_back(name); }
set_bottom_names(string name)46     void set_bottom_names(string name) { bottom_.push_back(name); }
set_node_name(string nname)47     void set_node_name(string nname) { nname_ = nname; }
set_node_type(string type)48     void set_node_type(string type) {type_ = type; }
set_mode(int mode)49     void set_mode(int mode) { mode_ = mode; }
set_bprop_flag(bool flag)50     void set_bprop_flag(bool flag) { bp_flag_ = flag; }
51 
get_node_name()52     string get_node_name() { return nname_; }
get_top_names()53     vector<string>& get_top_names() { return top_; }
get_bottom_names()54     vector<string>& get_bottom_names() { return bottom_; }
get_node_type()55     string get_node_type() { return type_; }
get_mode()56     int get_mode() { return mode_; }
get_bprop_flag()57     bool get_bprop_flag() { return bp_flag_; }
58 };
59 
60 class NNNode : public MLNode
61 {
62   public:
NNNode(NNParams * p,MLEngine * e)63     NNNode(NNParams* p, MLEngine* e) : MLNode(p, e)
64     {
65       for(int i = 0; i < 4; i++) tBasic_[i] = NULL;
66     }
67 
~NNNode(void)68     virtual ~NNNode(void)
69     {
70       for(int i = 0; i < 4; i++) if(tBasic_[i] != NULL) { delete tBasic_[i]; tBasic_[i] = NULL; }
71     }
72 
createTasks(list<Task * >,int)73     void createTasks(list<Task*>, int) {}
createStrategy(int)74     virtual void createStrategy(int) {}
75 
forwardPropagate()76     virtual void forwardPropagate() {}
backPropagate()77     virtual void backPropagate() {}
weightUpdate()78     virtual void weightUpdate() {}
solverStep()79     virtual void solverStep() {}
80 
executeTask(int taskId)81     int executeTask(int taskId)
82     {
83       if(taskId == 0)
84       {
85         forwardPropagate();
86       }
87       else if(taskId == 1)
88       {
89         backPropagate();
90       }
91       else if(taskId == 2)
92       {
93         weightUpdate();
94       }
95       else if(taskId == 3)
96       {
97         solverStep();
98       }
99       return 0;
100     }
101 
enqueTask(int pos)102     void enqueTask(int pos) {}
103 
createPersistentTask()104     virtual void createPersistentTask() {}
105 
setNextNode(NNNode * next)106     void setNextNode(NNNode* next)
107     {
108       //check if next is already in the nextNodes list
109       if(std::find(nextNodes_.begin(), nextNodes_.end(), next) == nextNodes_.end())
110       {
111         nextNodes_.push_back(next);
112         next->prevNodes_.push_back(this);
113       }
114     }
115 
setPrevNode(NNNode * prev)116     void setPrevNode(NNNode* prev)
117     {
118       //check if prev is already in the prevNodes list
119       if(std::find(prevNodes_.begin(), prevNodes_.end(), prev) == prevNodes_.end())
120       {
121         prevNodes_.push_back(prev);
122         prev->nextNodes_.push_back(this);
123       }
124     }
125 
getBasicTask(int type)126     Task *getBasicTask(int type)
127     {
128       int index = -1;
129       if(type == 0 || (type == 1 && bp_flag_) || (type > 1 && has_weights_))
130         index = type;
131       if(index != -1) {
132         if(tBasic_[index] == NULL) tBasic_[index] = new Task(this, -1, type);
133         return tBasic_[index];
134       }
135       return NULL;
136     }
137 
138     void createNNGraph(int mode);
139 
setNodeType(string type)140     void setNodeType(string type) { ntype_ = type; }
141 
getNodeType()142     string getNodeType() { return ntype_; }
getNodeName()143     string getNodeName() { return nname_; }
getMode()144     int getMode() { return mode_; }
145 
getNumPrevNodes()146     int getNumPrevNodes() { return prevNodes_.size(); }
getNumNextNodes()147     int getNumNextNodes() { return nextNodes_.size(); }
148 
getPrevNode(int i)149     NNNode* getPrevNode(int i) { if(prevNodes_.size() > 0) return prevNodes_[i]; else return NULL; }
getNextNode(int i)150     NNNode* getNextNode(int i) { if(nextNodes_.size() > 0) return nextNodes_[i]; else return NULL; }
151 
get_num_tops()152     int get_num_tops() { return top_.size(); }
set_top_compute_engine(int e)153     void set_top_compute_engine(int e) { top_compute_engine_ = e; }
get_bot_compute_engine()154     int get_bot_compute_engine() { return bot_compute_engine_; }
set_next_node_type(string s)155     void set_next_node_type(string s) {next_ntype_ = s;}
156 
refineTask()157     void refineTask(){}
158 
createCheckPoint()159     virtual void createCheckPoint() {}
restoreCheckPoint()160     virtual void restoreCheckPoint() {}
161 
162   protected:
163     string nname_, ntype_, next_ntype_;
164     vector<string> top_;
165     vector<string> bottom_;
166     int mode_;
167     bool bp_flag_;
168     bool has_weights_;
169     vector<NNNode*> prevNodes_;
170     vector<NNNode*> nextNodes_;
171     int top_compute_engine_, bot_compute_engine_;
172 #ifdef USE_MLSL
173     MLSL::Operation* op_;
174 #endif
175 
176 
177     // 0-Forw, 1-Back, 2-WGrad, 3-Solver
178     Task *tBasic_[4];
179 };
180 
181