1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.)
10 ******************************************************************************/
11 
12 
13 #pragma once
14 #include <string>
15 #include <stdio.h>
16 #include "Node.hpp"
17 #include "Engine.hpp"
18 #include "Params.hpp"
19 #include "Tensor.hpp"
20 #include "proto/gxm.pb.h"
21 #include "ReLUImpl.hpp"
22 #include "ReLUXSMM.hpp"
23 
24 using namespace std;
25 using namespace gxm;
26 
27 class ReLUParams : public NNParams
28 {
29   public:
ReLUParams(void)30     ReLUParams(void) {}
31 
~ReLUParams(void)32     virtual ~ReLUParams(void) {}
33 
set_negative_slope(float s)34     void set_negative_slope(float s) { neg_slope_ = s; }
get_negative_slope()35     float get_negative_slope() { return neg_slope_; }
36 
set_data_type(int t)37     void set_data_type(int t) { data_type_ = t; }
get_data_type()38     int get_data_type() { return data_type_; }
39 
set_compute_engine(int ce)40     void set_compute_engine(int ce) { compute_engine_ = ce; }
get_compute_engine()41     int get_compute_engine() { return compute_engine_; }
42 
set_algo_type(int at)43     void set_algo_type(int at) { algotype_ = at; }
get_algo_type()44     int get_algo_type() { return algotype_; }
45 
46   protected:
47     float neg_slope_;
48     int compute_engine_, algotype_, data_type_;
49 };
50 
parseReLUParams(NodeParameter * np)51 static MLParams* parseReLUParams(NodeParameter* np)
52 {
53   ReLUParams* rp = new ReLUParams();
54 
55   // Set name of node
56   string str = np->name();
57   assert(!str.empty());
58   rp->set_node_name(str);
59 
60   //Set node type (ReLU)
61   str = np->type();
62   assert(!str.empty());
63   rp->set_node_type(str);
64 
65   //Set tensor names
66   assert(np->bottom_size() == 1);
67   assert(!np->bottom(0).empty());
68   rp->set_bottom_names(np->bottom(0));
69 
70   assert(np->top_size() == 1);
71   assert(!np->top(0).empty());
72   rp->set_top_names(np->top(0));
73 
74   //Set Mode for the node
75   assert((np->mode() == TRAIN) || (np->mode() == TEST));
76   rp->set_mode(np->mode());
77 
78   //Set backprop needed/not needed flag for this node
79   rp->set_bprop_flag(np->propagate_down());
80 
81   ReLUParameter p = np->relu_param();
82 
83   rp->set_negative_slope(p.negative_slope());
84 
85   rp->set_data_type(p.data_type());
86   rp->set_compute_engine(p.engine());
87   rp->set_algo_type(p.algotype());
88 
89   return rp;
90 }
91 
// Graph node implementing the ReLU activation. Constructed from a
// ReLUParams object produced by parseReLUParams(); the constructor and the
// forward/backward passes are defined out-of-line (in the corresponding
// .cpp), so members here are documented but not initialized in this header.
class ReLUNode : public NNNode
{
  public:
    ReLUNode(ReLUParams* p, MLEngine* e);

    virtual ~ReLUNode(void) {}

  protected:
    void forwardPropagate();   // FP: apply ReLU to tenBotData_, write tenTopData_
    void backPropagate();      // BP: propagate gradients from tenTopDiff_ to tenBotDiff_
    void configure(int engine);  // select the backend implementation (impl)

    // Zero out all dimension entries of a Shape before it is (re)filled.
    void shape_setzero(Shape* s)
    {
      for(int i=0; i<MAX_DIMS; i++)
        s->dims[i] = 0;
    }

    Tensor* tenTop_; // Output tensor pointer
    Tensor* tenBot_; // Input tensor pointer
    ReLUImplParams gparams_;   // parameters handed to the backend implementation
    TensorBuf *tenBotDiff_, *tenBotData_; // Data & Gradients with respect to input
    TensorBuf *tenTopData_, *tenTopDiff_; // Output data
// NOTE(review): comment says "Output data" but the pair also holds the
    // incoming output-gradient buffer (tenTopDiff_) — confirm against the .cpp.

    int count_;        // total element count of the tensor (presumably; set in .cpp — verify)

    int bot_cengine_;  // compute engine of the producing (bottom) node
    Shape ts_;         // cached tensor shape
    ReLUImpl *impl;    // backend implementation selected by configure()
    MLEngine* eptr_;   // owning engine, kept for use during FP/BP
};
123