1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #pragma once
14 #include <string>
15 #include <stdio.h>
16 #include "assert.h"
17 #include "Node.hpp"
18 #include "Engine.hpp"
19 #include "Params.hpp"
20 #include "Solver.hpp"
21 #include "common.hpp"
22 #include "io.hpp"
23 #include "proto/gxm.pb.h"
24 #include "FCImpl.hpp"
25 #include "FCXSMM.hpp"
26 
27 using namespace std;
28 using namespace gxm;
29 
30 class FCParams : public NNParams
31 {
32   public:
FCParams(void)33     FCParams(void) {}
~FCParams(void)34     virtual ~FCParams(void) {}
35 
set_nOutput(int num_output)36     void set_nOutput(int num_output)  { this->nOutput_ = num_output;  }
get_output()37     int get_output() { return nOutput_; }
38 
set_activation_filler_type(string ftype)39     void set_activation_filler_type(string ftype) { afiller_type_ = ftype; }
get_activation_filler_type()40     string get_activation_filler_type() { return afiller_type_; }
41 
set_weight_filler_type(string ftype)42     void set_weight_filler_type(string ftype) { wfiller_type_ = ftype; }
get_weight_filler_type()43     string get_weight_filler_type() { return wfiller_type_; }
44 
set_std(float s)45     void set_std(float s) { std_ = s; }
get_std()46     float get_std() { return std_; }
47 
set_variance_norm(int v)48     void set_variance_norm(int v) { variance_norm_ = v; }
get_variance_norm()49     int get_variance_norm() { return variance_norm_; }
50 
set_bias_filler_type(string ftype)51     void set_bias_filler_type(string ftype) { bfiller_type_ = ftype; }
get_bias_filler_type()52     string get_bias_filler_type() { return bfiller_type_; }
53 
set_bias_term(bool bias)54     void set_bias_term(bool bias) { bias_term_ = bias; }
get_bias_term()55     bool get_bias_term() { return bias_term_; }
56 
set_value(float v)57     void set_value(float v) { value_ = v; }
get_value()58     float get_value() { return value_; }
59 
set_timeSteps(int nt)60     void set_timeSteps(int nt) { this->timesteps_ = nt; }
61 
set_transpose_flag(bool xpose)62     void set_transpose_flag(bool xpose) { transpose_ = xpose; }
get_transpose_flag()63     bool get_transpose_flag() { return transpose_; }
64 
set_data_type(int t)65     void set_data_type(int t) { data_type_ = t; }
get_data_type()66     int get_data_type() { return data_type_; }
67 
set_compute_engine(int ce)68     void set_compute_engine(int ce) { compute_engine_ = ce; }
get_compute_engine()69     int get_compute_engine() { return compute_engine_; }
70 
set_algo_type(int at)71     void set_algo_type(int at) { algotype_ = at; }
get_algo_type()72     int get_algo_type() { return algotype_; }
73 
set_global_params(vector<ParamSpec> psv)74     void set_global_params(vector<ParamSpec> psv)
75     {
76       for(int i=0; i<psv.size(); i++)
77       {
78         lr_mult_.push_back(psv[i].lr_mult());
79         decay_mult_.push_back(psv[i].decay_mult());
80       }
81     }
get_lr_mult()82     const vector<float>& get_lr_mult() { return lr_mult_; }
get_decay_mult()83     const vector<float>& get_decay_mult() { return decay_mult_; }
84 
85   protected:
86     int nOutput_, data_type_;
87     int timesteps_, compute_engine_, algotype_;
88     int variance_norm_;
89     bool transpose_;
90     string wfiller_type_, bfiller_type_, afiller_type_;
91     float std_, value_;
92     bool bias_term_;
93     vector<float> lr_mult_, decay_mult_;
94 };
95 
parseFCParams(NodeParameter * np)96 static MLParams* parseFCParams(NodeParameter* np)
97 {
98   FCParams* fcp = new FCParams();
99 
100   // Set name of node
101   assert(!np->name().empty());
102   fcp->set_node_name(np->name());
103 
104   //Set node type (Convolution, FullyConnected, etc)
105   assert(!np->type().empty());
106   fcp->set_node_type(np->type());
107 
108   //Set tensor names
109   assert(np->bottom_size() == 1);
110   assert(!np->bottom(0).empty());
111   fcp->set_bottom_names(np->bottom(0));
112 
113   assert(np->top_size() == 1);
114   assert(!np->top(0).empty());
115   fcp->set_top_names(np->top(0));
116 
117   //Set Mode for the node
118   assert((np->mode() == TRAIN) || (np->mode() == TEST));
119   fcp->set_mode(np->mode());
120 
121   //Set backprop needed/not needed flag for this node
122   fcp->set_bprop_flag(np->propagate_down());
123 
124   // Set global parameters such as learning rate multiplier etc.
125   vector<ParamSpec> psv;
126   for(int i=0; i<np->param_size(); i++)
127     psv.push_back(np->param(i));
128   fcp->set_global_params(psv);
129 
130   FullyConnectedParameter pfcp = np->fc_param();
131 
132   int num_output = pfcp.num_output();
133   fcp->set_nOutput(num_output);
134 
135   FillerParameter wp = pfcp.weight_filler();
136   fcp->set_weight_filler_type(wp.type());
137   fcp->set_std(wp.std());
138   fcp->set_variance_norm(wp.variance_norm());
139 
140   bool bias_term = pfcp.bias_term();
141   fcp->set_bias_term(bias_term);
142 
143   if(bias_term)
144   {
145     FillerParameter bp = pfcp.bias_filler();
146     fcp->set_bias_filler_type(bp.type());
147     fcp->set_value(bp.value());
148   }
149 
150   bool xpose = pfcp.transpose();
151   if(xpose)
152     fcp->set_transpose_flag(xpose);
153 
154   bool activation_term = pfcp.activation_term();
155   if(activation_term)
156   {
157     FillerParameter ap = pfcp.activation_filler();
158     fcp->set_activation_filler_type(ap.type());
159     fcp->set_value(ap.value());
160   }
161 
162   int nt = pfcp.num_timesteps();
163   fcp->set_timeSteps(nt);
164 
165   fcp->set_data_type(pfcp.data_type());
166   fcp->set_compute_engine(pfcp.engine());
167   fcp->set_algo_type(pfcp.algotype());
168 
169   return fcp;
170 }
171 
172 class FCNode: public NNNode
173 {
174   public:
175     FCNode(FCParams *p, MLEngine* e);
176 
~FCNode(void)177     virtual ~FCNode(void) {}
178 
get_weight_filler_type()179     string get_weight_filler_type() { return wfiller_type_; }
get_std()180     float get_std() { return std_; }
181 
get_bias_filler_type()182     string get_bias_filler_type() { return bfiller_type_; }
get_value()183     float get_value() { return value_; }
184 
185     void fillWeightBuffers(TensorBuf* tBuf, int buftype, long long int size);
186     void fillWeightMultipliers(float* lr_mult, float* decay_mult, long long int bytes);
187     void fillBiasBuffers(TensorBuf* tBuf, int buftype, long long int size);
188     void fillBiasMultipliers(float *lr_mult, float *decay_mult, long long int bytes);
189     void Checkpoint(TensorBuf *ptr, string name, string format);
190     void convert_bf16_f32(libxsmm_bfloat16*, float*, int);
191     void convert_f32_bf16(float*, libxsmm_bfloat16*, int);
192 
193   protected:
194     void forwardPropagate();
195     void backPropagate();
196     void weightUpdate();
197     void solverStep();
198     void truncate_mask_fp32_bfp16(float* in, float* out, unsigned int len);
shape_setzero(Shape * s)199     void shape_setzero(Shape* s)
200     {
201       for(int i=0; i<MAX_DIMS; i++)
202         s->dims[i] = 0;
203     }
204 
205     void configure(int engine);
206 
207     Tensor *tenTop_=NULL; // Output tensor pointer
208     Tensor *tenBot_=NULL; // Input tensor pointer
209     Tensor *tenWeight_=NULL; // Weight tensor pointer
210     Tensor *tenBias_=NULL;
211     FCImplParams gparams_;
212     TensorBuf *tenBotDiff_=NULL, *tenBotData_=NULL;
213     TensorBuf *tenTopData_=NULL, *tenTopDiff_=NULL;
214     TensorBuf *tenWeightDiff_=NULL, *tenWeightData_=NULL, *tenWeightInc_=NULL;
215     TensorBuf *tenBiasData_=NULL, *tenBiasDiff_=NULL, *tenBiasInc_=NULL;
216     TensorBuf *tenScratchData_=NULL;
217     Shape bs_, ts_, ws_;
218 
219     int bot_cengine_;
220 
221     int count_;
222 
223     string wfiller_type_, bfiller_type_;
224     string weight_, bias_;
225     float std_, value_;
226     int variance_norm_;
227     float *stptr=NULL, cbptr[16];
228     int in_dtype, out_dtype;
229     float *dwptr=NULL;
230 
231     vector<float> lr_mult_, decay_mult_;
232 
233     FCImpl* impl;
234     SolverNode* solver_;
235     MLEngine* eptr_;
236 };
237 
238