1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "rnn.h"
16
17 #include <math.h>
18
19 namespace ncnn {
20
RNN()21 RNN::RNN()
22 {
23 one_blob_only = false;
24 support_inplace = false;
25 }
26
load_param(const ParamDict & pd)27 int RNN::load_param(const ParamDict& pd)
28 {
29 num_output = pd.get(0, 0);
30 weight_data_size = pd.get(1, 0);
31 direction = pd.get(2, 0);
32 return 0;
33 }
34
load_model(const ModelBin & mb)35 int RNN::load_model(const ModelBin& mb)
36 {
37 int num_directions = direction == 2 ? 2 : 1;
38
39 int size = weight_data_size / num_directions / num_output;
40
41 // raw weight data
42 weight_xc_data = mb.load(size, num_output, num_directions, 0);
43 if (weight_xc_data.empty())
44 return -100;
45
46 bias_c_data = mb.load(num_output, 1, num_directions, 0);
47 if (bias_c_data.empty())
48 return -100;
49
50 weight_hc_data = mb.load(num_output, num_output, num_directions, 0);
51 if (weight_hc_data.empty())
52 return -100;
53
54 return 0;
55 }
56
rnn(const Mat & bottom_blob,Mat & top_blob,int reverse,const Mat & weight_xc,const Mat & bias_c,const Mat & weight_hc,Mat & hidden_state,const Option & opt)57 static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt)
58 {
59 int size = bottom_blob.w;
60 int T = bottom_blob.h;
61
62 int num_output = top_blob.w;
63
64 // num_output
65 Mat gates(num_output, 4u, opt.workspace_allocator);
66 if (gates.empty())
67 return -100;
68
69 // unroll
70 for (int t = 0; t < T; t++)
71 {
72 int ti = reverse ? T - 1 - t : t;
73
74 const float* x = bottom_blob.row(ti);
75
76 for (int q = 0; q < num_output; q++)
77 {
78 const float* weight_xc_ptr = weight_xc.row(q);
79
80 const float* weight_hc_ptr = weight_hc.row(q);
81
82 float H = bias_c[q];
83
84 for (int i = 0; i < size; i++)
85 {
86 H += weight_xc_ptr[i] * x[i];
87 }
88
89 for (int i = 0; i < num_output; i++)
90 {
91 H += weight_hc_ptr[i] * hidden_state[i];
92 }
93
94 H = tanh(H);
95
96 gates[q] = H;
97 }
98
99 float* output_data = top_blob.row(ti);
100 for (int q = 0; q < num_output; q++)
101 {
102 float H = gates[q];
103
104 hidden_state[q] = H;
105 output_data[q] = H;
106 }
107 }
108
109 return 0;
110 }
111
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const112 int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
113 {
114 int T = bottom_blob.h;
115
116 int num_directions = direction == 2 ? 2 : 1;
117
118 // initial hidden state
119 Mat hidden(num_output, 4u, opt.workspace_allocator);
120 if (hidden.empty())
121 return -100;
122 hidden.fill(0.f);
123
124 top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
125 if (top_blob.empty())
126 return -100;
127
128 // Uni directional
129 if (direction == 0 || direction == 1)
130 {
131 int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
132 if (ret != 0)
133 return ret;
134 }
135
136 if (direction == 2)
137 {
138 Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
139 if (top_blob_forward.empty())
140 return -100;
141
142 Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
143 if (top_blob_reverse.empty())
144 return -100;
145
146 int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
147 if (ret0 != 0)
148 return ret0;
149
150 hidden.fill(0.0f);
151
152 int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt);
153 if (ret1 != 0)
154 return ret1;
155
156 // concat w
157 for (int i = 0; i < T; i++)
158 {
159 const float* pf = top_blob_forward.row(i);
160 const float* pr = top_blob_reverse.row(i);
161 float* ptr = top_blob.row(i);
162
163 memcpy(ptr, pf, num_output * sizeof(float));
164 memcpy(ptr + num_output, pr, num_output * sizeof(float));
165 }
166 }
167
168 return 0;
169 }
170
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const171 int RNN::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
172 {
173 const Mat& bottom_blob = bottom_blobs[0];
174 int T = bottom_blob.h;
175 int num_directions = direction == 2 ? 2 : 1;
176
177 Mat hidden;
178 Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator;
179 if (bottom_blobs.size() == 2)
180 {
181 hidden = bottom_blobs[1].clone(hidden_allocator);
182 }
183 else
184 {
185 hidden.create(num_output, num_directions, 4u, hidden_allocator);
186 if (hidden.empty())
187 return -100;
188 hidden.fill(0.f);
189 }
190
191 Mat& top_blob = top_blobs[0];
192 top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
193 if (top_blob.empty())
194 return -100;
195
196 // Uni directional
197 if (direction == 0 || direction == 1)
198 {
199 int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt);
200 if (ret != 0)
201 return ret;
202 }
203
204 if (direction == 2)
205 {
206 Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
207 if (top_blob_forward.empty())
208 return -100;
209
210 Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
211 if (top_blob_reverse.empty())
212 return -100;
213
214 Mat hidden0 = hidden.row_range(0, 1);
215 int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt);
216 if (ret0 != 0)
217 return ret0;
218
219 Mat hidden1 = hidden.row_range(1, 1);
220 int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt);
221 if (ret1 != 0)
222 return ret1;
223
224 // concat w
225 for (int i = 0; i < T; i++)
226 {
227 const float* pf = top_blob_forward.row(i);
228 const float* pr = top_blob_reverse.row(i);
229 float* ptr = top_blob.row(i);
230
231 memcpy(ptr, pf, num_output * sizeof(float));
232 memcpy(ptr + num_output, pr, num_output * sizeof(float));
233 }
234 }
235
236 if (top_blobs.size() == 2)
237 {
238 top_blobs[1] = hidden;
239 }
240
241 return 0;
242 }
243
244 } // namespace ncnn
245