1 // Tencent is pleased to support the open source community by making ncnn available. 2 // 3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 4 // 5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 // in compliance with the License. You may obtain a copy of the License at 7 // 8 // https://opensource.org/licenses/BSD-3-Clause 9 // 10 // Unless required by applicable law or agreed to in writing, software distributed 11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 // specific language governing permissions and limitations under the License. 14 15 #include "pass_ncnn.h" 16 17 namespace pnnx { 18 19 namespace ncnn { 20 21 class nn_RNN : public GraphRewriterPass 22 { 23 public: match_pattern_graph() const24 const char* match_pattern_graph() const 25 { 26 return R"PNNXIR(7767517 27 3 3 28 pnnx.Input input 0 1 input 29 nn.RNN op_0 1 2 input out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse 30 pnnx.Output output 2 0 out out_hidden 31 )PNNXIR"; 32 } 33 type_str() const34 const char* type_str() const 35 { 36 return "RNN"; 37 } 38 name_str() const39 const char* name_str() const 40 { 41 return "rnn"; 42 } 43 write(Operator * op,const std::map<std::string,Parameter> & captured_params,const std::map<std::string,Attribute> & captured_attrs) const44 void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const 45 { 46 const std::string nonlinearity = captured_params.at("nonlinearity").s; 47 48 if (nonlinearity != "tanh") 49 { 50 fprintf(stderr, "RNN nonlinearity=%s not supported\n", nonlinearity.c_str()); 51 } 52 53 const bool bidirectional = captured_params.at("bidirectional").b; 54 const int num_directions = bidirectional ? 2 : 1; 55 const int num_output = captured_params.at("hidden_size").i; 56 const int input_size = captured_params.at("input_size").i; 57 58 int weight_data_size = num_directions * num_output * input_size; 59 60 op->params["0"] = num_output; 61 op->params["1"] = weight_data_size; 62 op->params["2"] = bidirectional ? 2 : 0; 63 64 op->attrs["0"] = Attribute(); 65 op->attrs["0"].data = {0, 0, 0, 0}; 66 op->attrs["1"] = captured_attrs.at("op_0.weight_ih_l0"); 67 if (bidirectional) 68 op->attrs["2"] = captured_attrs.at("op_0.weight_ih_l0_reverse"); 69 70 op->attrs["3"] = Attribute(); 71 op->attrs["3"].data = {0, 0, 0, 0}; 72 if (captured_params.at("bias").b) 73 { 74 // reduce bias_ih and bias_hh 75 std::vector<float> new_bias; 76 { 77 const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data(); 78 const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data(); 79 80 new_bias.resize(num_output); 81 float* bias = (float*)new_bias.data(); 82 for (int i = 0; i < num_output; i++) 83 { 84 bias[i] = bias_ih[i] + bias_hh[i]; 85 } 86 } 87 88 op->attrs["4"] = Attribute({num_output}, new_bias); 89 90 if (bidirectional) 91 { 92 std::vector<float> new_bias_reverse; 93 { 94 const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data(); 95 const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data(); 96 97 new_bias_reverse.resize(num_output); 98 float* bias = (float*)new_bias_reverse.data(); 99 for (int i = 0; i < num_output; i++) 100 { 101 bias[i] = bias_ih[i] + bias_hh[i]; 102 } 103 } 104 105 op->attrs["5"] = Attribute({num_output}, new_bias_reverse); 106 } 107 } 108 else 109 { 110 std::vector<float> bias(num_output, 0.f); 111 op->attrs["4"] = Attribute({num_output}, bias); 112 113 if (bidirectional) 114 { 115 op->attrs["5"] = Attribute({num_output}, bias); 116 } 117 } 118 119 op->attrs["6"] = Attribute(); 120 op->attrs["6"].data = {0, 0, 0, 0}; 121 op->attrs["7"] = captured_attrs.at("op_0.weight_hh_l0"); 122 if (bidirectional) 123 op->attrs["8"] = captured_attrs.at("op_0.weight_hh_l0_reverse"); 124 } 125 }; 126 127 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN, 20) 128 129 class nn_RNN_1 : public nn_RNN 130 { 131 public: match_pattern_graph() const132 const char* match_pattern_graph() const 133 { 134 return R"PNNXIR(7767517 135 4 4 136 pnnx.Input input 0 1 input 137 pnnx.Input in_hidden 0 1 in_hidden 138 nn.RNN op_0 2 2 input in_hidden out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse 139 pnnx.Output output 2 0 out out_hidden 140 )PNNXIR"; 141 } 142 }; 143 144 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_1, 20) 145 146 class nn_RNN_2 : public nn_RNN 147 { 148 public: match_pattern_graph() const149 const char* match_pattern_graph() const 150 { 151 return R"PNNXIR(7767517 152 3 2 153 pnnx.Input input 0 1 input 154 nn.RNN op_0 1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse 155 pnnx.Output output 1 0 out 156 )PNNXIR"; 157 } 158 }; 159 160 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_2, 20) 161 162 class nn_RNN_3 : public nn_RNN 163 { 164 public: match_pattern_graph() const165 const char* match_pattern_graph() const 166 { 167 return R"PNNXIR(7767517 168 4 3 169 pnnx.Input input 0 1 input 170 pnnx.Input in_hidden 0 1 in_hidden 171 nn.RNN op_0 2 1 input in_hidden out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse 172 pnnx.Output output 1 0 out 173 )PNNXIR"; 174 } 175 }; 176 177 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_3, 20) 178 179 } // namespace ncnn 180 181 } // namespace pnnx 182