1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "pass_ncnn.h"
16 
17 namespace pnnx {
18 
19 namespace ncnn {
20 
21 class nn_RNN : public GraphRewriterPass
22 {
23 public:
match_pattern_graph() const24     const char* match_pattern_graph() const
25     {
26         return R"PNNXIR(7767517
27 3 3
28 pnnx.Input              input       0 1 input
29 nn.RNN                  op_0        1 2 input out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
30 pnnx.Output             output      2 0 out out_hidden
31 )PNNXIR";
32     }
33 
type_str() const34     const char* type_str() const
35     {
36         return "RNN";
37     }
38 
name_str() const39     const char* name_str() const
40     {
41         return "rnn";
42     }
43 
write(Operator * op,const std::map<std::string,Parameter> & captured_params,const std::map<std::string,Attribute> & captured_attrs) const44     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
45     {
46         const std::string nonlinearity = captured_params.at("nonlinearity").s;
47 
48         if (nonlinearity != "tanh")
49         {
50             fprintf(stderr, "RNN nonlinearity=%s not supported\n", nonlinearity.c_str());
51         }
52 
53         const bool bidirectional = captured_params.at("bidirectional").b;
54         const int num_directions = bidirectional ? 2 : 1;
55         const int num_output = captured_params.at("hidden_size").i;
56         const int input_size = captured_params.at("input_size").i;
57 
58         int weight_data_size = num_directions * num_output * input_size;
59 
60         op->params["0"] = num_output;
61         op->params["1"] = weight_data_size;
62         op->params["2"] = bidirectional ? 2 : 0;
63 
64         op->attrs["0"] = Attribute();
65         op->attrs["0"].data = {0, 0, 0, 0};
66         op->attrs["1"] = captured_attrs.at("op_0.weight_ih_l0");
67         if (bidirectional)
68             op->attrs["2"] = captured_attrs.at("op_0.weight_ih_l0_reverse");
69 
70         op->attrs["3"] = Attribute();
71         op->attrs["3"].data = {0, 0, 0, 0};
72         if (captured_params.at("bias").b)
73         {
74             // reduce bias_ih and bias_hh
75             std::vector<float> new_bias;
76             {
77                 const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data();
78                 const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data();
79 
80                 new_bias.resize(num_output);
81                 float* bias = (float*)new_bias.data();
82                 for (int i = 0; i < num_output; i++)
83                 {
84                     bias[i] = bias_ih[i] + bias_hh[i];
85                 }
86             }
87 
88             op->attrs["4"] = Attribute({num_output}, new_bias);
89 
90             if (bidirectional)
91             {
92                 std::vector<float> new_bias_reverse;
93                 {
94                     const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data();
95                     const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data();
96 
97                     new_bias_reverse.resize(num_output);
98                     float* bias = (float*)new_bias_reverse.data();
99                     for (int i = 0; i < num_output; i++)
100                     {
101                         bias[i] = bias_ih[i] + bias_hh[i];
102                     }
103                 }
104 
105                 op->attrs["5"] = Attribute({num_output}, new_bias_reverse);
106             }
107         }
108         else
109         {
110             std::vector<float> bias(num_output, 0.f);
111             op->attrs["4"] = Attribute({num_output}, bias);
112 
113             if (bidirectional)
114             {
115                 op->attrs["5"] = Attribute({num_output}, bias);
116             }
117         }
118 
119         op->attrs["6"] = Attribute();
120         op->attrs["6"].data = {0, 0, 0, 0};
121         op->attrs["7"] = captured_attrs.at("op_0.weight_hh_l0");
122         if (bidirectional)
123             op->attrs["8"] = captured_attrs.at("op_0.weight_hh_l0_reverse");
124     }
125 };
126 
127 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN, 20)
128 
// Same rewrite as nn_RNN, but matches the variant that also feeds an
// initial hidden state (in_hidden) into nn.RNN; write() is inherited.
class nn_RNN_1 : public nn_RNN
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 4
pnnx.Input              input       0 1 input
pnnx.Input              in_hidden   0 1 in_hidden
nn.RNN                  op_0        2 2 input in_hidden out out_hidden input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
pnnx.Output             output      2 0 out out_hidden
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_1, 20)
145 
// Same rewrite as nn_RNN, but matches the variant whose hidden-state output
// is unused (only the sequence output is consumed); write() is inherited.
class nn_RNN_2 : public nn_RNN
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
3 2
pnnx.Input              input       0 1 input
nn.RNN                  op_0        1 1 input out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
pnnx.Output             output      1 0 out
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_2, 20)
161 
// Same rewrite as nn_RNN, but matches the variant with an initial hidden
// state input (in_hidden) whose hidden-state output is unused; write() is
// inherited.
class nn_RNN_3 : public nn_RNN
{
public:
    const char* match_pattern_graph() const
    {
        return R"PNNXIR(7767517
4 3
pnnx.Input              input       0 1 input
pnnx.Input              in_hidden   0 1 in_hidden
nn.RNN                  op_0        2 1 input in_hidden out input_size=%input_size hidden_size=%hidden_size num_layers=1 nonlinearity=%nonlinearity bias=%bias batch_first=%batch_first bidirectional=%bidirectional @weight_ih_l0 @weight_hh_l0 @bias_ih_l0 @bias_hh_l0 @weight_ih_l0_reverse @weight_hh_l0_reverse @bias_ih_l0_reverse @bias_hh_l0_reverse
pnnx.Output             output      1 0 out
)PNNXIR";
    }
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RNN_3, 20)
178 
179 } // namespace ncnn
180 
181 } // namespace pnnx
182