// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rnn.h"

#include <math.h>
#include <string.h> // memcpy

namespace ncnn {

RNN::RNN()
{
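    // this layer takes multiple input blobs and cannot operate in place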
    one_blob_only = false;
    support_inplace = false;
}

int RNN::load_param(const ParamDict& pd)
{
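    // 0: num_output, the hidden state size
    // 1: weight_data_size, total element count of weight_xc_data
    // 2: direction, 0 = forward, 1 = reverse, 2 = bidirectional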
    num_output = pd.get(0, 0);
    weight_data_size = pd.get(1, 0);
    direction = pd.get(2, 0);
    return 0;
}

int RNN::load_model(const ModelBin& mb)
{
    int num_directions = direction == 2 ? 2 : 1;

    // infer the input feature size from the total weight count
    int size = weight_data_size / num_directions / num_output;

    // raw weight data
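    // weight_xc_data: input-to-hidden weights, size x num_output per direction
    // bias_c_data:    hidden bias, num_output per direction
    // weight_hc_data: hidden-to-hidden weights, num_output x num_output per direction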
    weight_xc_data = mb.load(size, num_output, num_directions, 0);
    if (weight_xc_data.empty())
        return -100;

    bias_c_data = mb.load(num_output, 1, num_directions, 0);
    if (bias_c_data.empty())
        return -100;

    weight_hc_data = mb.load(num_output, num_output, num_directions, 0);
    if (weight_hc_data.empty())
        return -100;

    return 0;
}

static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
{
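    // vanilla RNN unrolled over T timesteps:
    //   h(t) = tanh(weight_xc * x(t) + weight_hc * h(t-1) + bias_c)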
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // gates holds h(t) for all num_output units so that hidden_state
    // still contains h(t-1) while the current timestep is being computed
    Mat gates(num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
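        // walk the sequence backwards when running the reverse direction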
        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);

        for (int q = 0; q < num_output; q++)
        {
            const float* weight_xc_ptr = weight_xc.row(q);

            const float* weight_hc_ptr = weight_hc.row(q);

            float H = bias_c[q];

            // input-to-hidden: dot(weight_xc row q, x(t))
            for (int i = 0; i < size; i++)
            {
                H += weight_xc_ptr[i] * x[i];
            }

            // hidden-to-hidden: dot(weight_hc row q, h(t-1))
            for (int i = 0; i < num_output; i++)
            {
                H += weight_hc_ptr[i] * hidden_state[i];
            }

            H = tanh(H);

            gates[q] = H;
        }

        // commit h(t) and write this timestep's output row
        float* output_data = top_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            float H = gates[q];

            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}

int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
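    // single-input overload: the initial hidden state is implicitly zero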
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    // one row per timestep; a bidirectional pass concatenates both directions along w
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
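        // one pass over the whole sequence; direction (0 or 1) doubles as the reverse flag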
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
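        // run each direction into its own temporary blob, then interleave per timestep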
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
        if (ret0 != 0)
            return ret0;

        // reset the shared hidden state before the reverse pass
        hidden.fill(0.0f);

        int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
        if (ret1 != 0)
            return ret1;

        // interleave: each output row is [forward | reverse]
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

int RNN::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
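    // multi-blob overload: an optional second input carries the initial hidden state
    // and an optional second output receives the final hidden state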
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    int num_directions = direction == 2 ? 2 : 1;

    Mat hidden;
    // when the final hidden state is returned as an output, allocate it from the
    // blob allocator so it outlives this call
    Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
    if (bottom_blobs.size() == 2)
    {
        hidden = bottom_blobs[1].clone(hidden_allocator);
    }
    else
    {
        hidden.create(num_output, num_directions, 4u, hidden_allocator);
        if (hidden.empty())
            return -100;
        hidden.fill(0.f);
    }

    Mat& top_blob = top_blobs[0];
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // unidirectional
    if (direction == 0 || direction == 1)
    {
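        // direction (0 or 1) doubles as the reverse flag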
        int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
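        // run each direction separately; each owns one row of the hidden state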
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        Mat hidden0 = hidden.row_range(0, 1);
        int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
        if (ret0 != 0)
            return ret0;

        Mat hidden1 = hidden.row_range(1, 1);
        int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
        if (ret1 != 0)
            return ret1;

        // interleave: each output row is [forward | reverse]
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    // hand the final hidden state to the optional second output
    if (top_blobs.size() == 2)
    {
        top_blobs[1] = hidden;
    }

    return 0;
}

} // namespace ncnn