1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Sasikanth Avancha, Dhiraj Kalamkar (Intel Corp.) 10 ******************************************************************************/ 11 12 13 #pragma once 14 #include <string> 15 #include <stdio.h> 16 #include "Node.hpp" 17 #include "Engine.hpp" 18 #include "Params.hpp" 19 #include "Tensor.hpp" 20 #include "proto/gxm.pb.h" 21 #include "ReLUImpl.hpp" 22 #include "ReLUXSMM.hpp" 23 24 using namespace std; 25 using namespace gxm; 26 27 class ReLUParams : public NNParams 28 { 29 public: ReLUParams(void)30 ReLUParams(void) {} 31 ~ReLUParams(void)32 virtual ~ReLUParams(void) {} 33 set_negative_slope(float s)34 void set_negative_slope(float s) { neg_slope_ = s; } get_negative_slope()35 float get_negative_slope() { return neg_slope_; } 36 set_data_type(int t)37 void set_data_type(int t) { data_type_ = t; } get_data_type()38 int get_data_type() { return data_type_; } 39 set_compute_engine(int ce)40 void set_compute_engine(int ce) { compute_engine_ = ce; } get_compute_engine()41 int get_compute_engine() { return compute_engine_; } 42 set_algo_type(int at)43 void set_algo_type(int at) { algotype_ = at; } get_algo_type()44 int get_algo_type() { return algotype_; } 45 46 protected: 47 float neg_slope_; 48 int compute_engine_, algotype_, data_type_; 49 }; 50 parseReLUParams(NodeParameter * np)51static MLParams* parseReLUParams(NodeParameter* np) 52 { 53 ReLUParams* rp = new ReLUParams(); 54 55 // Set name of node 56 string str = np->name(); 57 assert(!str.empty()); 58 rp->set_node_name(str); 59 60 //Set node type (ReLU) 61 str = np->type(); 62 assert(!str.empty()); 63 rp->set_node_type(str); 64 65 //Set tensor names 66 assert(np->bottom_size() == 1); 67 assert(!np->bottom(0).empty()); 68 rp->set_bottom_names(np->bottom(0)); 69 70 assert(np->top_size() == 1); 71 assert(!np->top(0).empty()); 72 rp->set_top_names(np->top(0)); 73 74 //Set Mode for the node 75 assert((np->mode() == TRAIN) || (np->mode() == TEST)); 76 rp->set_mode(np->mode()); 77 78 //Set backprop needed/not needed flag for this node 79 rp->set_bprop_flag(np->propagate_down()); 80 81 ReLUParameter p = np->relu_param(); 82 83 rp->set_negative_slope(p.negative_slope()); 84 85 rp->set_data_type(p.data_type()); 86 rp->set_compute_engine(p.engine()); 87 rp->set_algo_type(p.algotype()); 88 89 return rp; 90 } 91 92 class ReLUNode : public NNNode 93 { 94 public: 95 ReLUNode(ReLUParams* p, MLEngine* e); 96 ~ReLUNode(void)97 virtual ~ReLUNode(void) {} 98 99 protected: 100 void forwardPropagate(); 101 void backPropagate(); 102 void configure(int engine); 103 shape_setzero(Shape * s)104 void shape_setzero(Shape* s) 105 { 106 for(int i=0; i<MAX_DIMS; i++) 107 s->dims[i] = 0; 108 } 109 110 Tensor* tenTop_; // Output tensor pointer 111 Tensor* tenBot_; // Input tensor pointer 112 ReLUImplParams gparams_; 113 TensorBuf *tenBotDiff_, *tenBotData_; // Data & Gradients with respect to input 114 TensorBuf *tenTopData_, *tenTopDiff_; // Output data 115 116 int count_; 117 118 int bot_cengine_; 119 Shape ts_; 120 ReLUImpl *impl; 121 MLEngine* eptr_; 122 }; 123