// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "lstm.h"

#include <math.h>
#include <string.h> // memcpy

namespace ncnn {

LSTM::LSTM()
{
    one_blob_only = false;
    support_inplace = false;
}

int LSTM::load_param(const ParamDict& pd)
{
    num_output = pd.get(0, 0);
    weight_data_size = pd.get(1, 0);
    direction = pd.get(2, 0);
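    // direction: 0 = forward, 1 = reverse, 2 = bidirectional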
    if (direction == 2)
        one_blob_only = true;
    return 0;
}

int LSTM::load_model(const ModelBin& mb)
{
    int num_directions = direction == 2 ? 2 : 1;

    int size = weight_data_size / num_directions / num_output / 4;

    // raw weight data
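    // weight_xc_data: (size, 4 * num_output, num_directions), input-to-hidden weights in gate order I F O G
    // bias_c_data:    (num_output, 4, num_directions), one bias row per gate
    // weight_hc_data: (num_output, 4 * num_output, num_directions), hidden-to-hidden weights in the same gate order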
    weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
    if (weight_xc_data.empty())
        return -100;

    bias_c_data = mb.load(num_output, 4, num_directions, 0);
    if (bias_c_data.empty())
        return -100;

    weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
    if (weight_hc_data.empty())
        return -100;

    return 0;
}

static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
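    // one row of pre-activation gate values (I, F, O, G) per output cell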
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0 otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

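        // reverse == 1 walks the sequence from the last time step back to the first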
        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data = gates.row(q);

            // gate I F O G
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            float I = bias_c_I[q];
            float F = bias_c_F[q];
            float O = bias_c_O[q];
            float G = bias_c_G[q];

            for (int i = 0; i < size; i++)
            {
                float xi = x[i];

                I += weight_xc_I[i] * xi;
                F += weight_xc_F[i] * xi;
                O += weight_xc_O[i] * xi;
                G += weight_xc_G[i] * xi;
            }

            for (int i = 0; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

                I += weight_hc_I[i] * h_cont;
                F += weight_hc_F[i] * h_cont;
                O += weight_hc_O[i] * h_cont;
                G += weight_hc_G[i] * h_cont;
            }

            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * cell_state[q] + I * G;
            float H = O * tanh(cell2);
            cell_state[q] = cell2;
            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}

int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

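    // when direction == 2 each output row stores the forward result followed by the reverse result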
    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);
        cell.fill(0.0f);

        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat w
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }
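    // bottom_blobs: input sequence, initial hidden state, initial cell state
    // top_blobs:    output sequence, final hidden state, final cell state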
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];
    Mat& cell_state = top_blobs[2];

    // copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
    cell_state = bottom_blobs[2].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt);
        if (ret != 0)
            return ret;
    }

    return 0;
}

} // namespace ncnn