// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "lstm.h"

#include <math.h>
#include <string.h> // memcpy

namespace ncnn {

LSTM::LSTM()
{
    one_blob_only = false;
    support_inplace = false;
}

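// layer params, as read in load_param below
//   0 = num_output        output elements (hidden size) per time step
//   1 = weight_data_size  element count of weight_xc_data (size * num_output * 4 * num_directions)
//   2 = direction         0 = forward, 1 = reverse, 2 = bidirectional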
int LSTM::load_param(const ParamDict& pd)
{
    num_output = pd.get(0, 0);
    weight_data_size = pd.get(1, 0);
    direction = pd.get(2, 0);
    if (direction == 2)
        one_blob_only = true;
    return 0;
}

int LSTM::load_model(const ModelBin& mb)
{
    int num_directions = direction == 2 ? 2 : 1;

    int size = weight_data_size / num_directions / num_output / 4;

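    // per direction, the four gates are packed in I F O G order
    //   weight_xc: 4 * num_output rows of size        (input-to-hidden)
    //   bias_c:    4 rows of num_output
    //   weight_hc: 4 * num_output rows of num_output  (hidden-to-hidden)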
    // raw weight data
    weight_xc_data = mb.load(size, num_output * 4, num_directions, 0);
    if (weight_xc_data.empty())
        return -100;

    bias_c_data = mb.load(num_output, 4, num_directions, 0);
    if (bias_c_data.empty())
        return -100;

    weight_hc_data = mb.load(num_output, num_output * 4, num_directions, 0);
    if (weight_hc_data.empty())
        return -100;

    return 0;
}

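// run one direction of the lstm over the whole sequence
//   bottom_blob: T rows of size (input features per time step)
//   top_blob:    T rows of num_output, written at the original time positions
//   reverse:     0 = iterate t from 0..T-1, 1 = iterate from T-1..0
//   hidden_state / cell_state: carried across time steps and updated in place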
static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
    int size = bottom_blob.w;
    int T = bottom_blob.h;

    int num_output = top_blob.w;

    // 4 x num_output
    Mat gates(4, num_output, 4u, opt.workspace_allocator);
    if (gates.empty())
        return -100;

    // unroll
    for (int t = 0; t < T; t++)
    {
        // clip hidden by continuation indicator
        // h_cont_{t-1} = cont_t * h_{t-1}
        // h_cont_{t-1} = h_{t-1} if cont_t == 1
        //                0       otherwise
        // calculate hidden
        // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

        int ti = reverse ? T - 1 - t : t;

        const float* x = bottom_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* bias_c_I = bias_c.row(0);
            const float* bias_c_F = bias_c.row(1);
            const float* bias_c_O = bias_c.row(2);
            const float* bias_c_G = bias_c.row(3);

            float* gates_data = gates.row(q);

            // gate I F O G
            const float* weight_xc_I = weight_xc.row(num_output * 0 + q);
            const float* weight_xc_F = weight_xc.row(num_output * 1 + q);
            const float* weight_xc_O = weight_xc.row(num_output * 2 + q);
            const float* weight_xc_G = weight_xc.row(num_output * 3 + q);

            const float* weight_hc_I = weight_hc.row(num_output * 0 + q);
            const float* weight_hc_F = weight_hc.row(num_output * 1 + q);
            const float* weight_hc_O = weight_hc.row(num_output * 2 + q);
            const float* weight_hc_G = weight_hc.row(num_output * 3 + q);

            float I = bias_c_I[q];
            float F = bias_c_F[q];
            float O = bias_c_O[q];
            float G = bias_c_G[q];

            for (int i = 0; i < size; i++)
            {
                float xi = x[i];

                I += weight_xc_I[i] * xi;
                F += weight_xc_F[i] * xi;
                O += weight_xc_O[i] * xi;
                G += weight_xc_G[i] * xi;
            }

            for (int i = 0; i < num_output; i++)
            {
                float h_cont = hidden_state[i];

                I += weight_hc_I[i] * h_cont;
                F += weight_hc_F[i] * h_cont;
                O += weight_hc_O[i] * h_cont;
                G += weight_hc_G[i] * h_cont;
            }

            gates_data[0] = I;
            gates_data[1] = F;
            gates_data[2] = O;
            gates_data[3] = G;
        }

        // lstm unit
        // sigmoid(I)
        // sigmoid(F)
        // sigmoid(O)
        // tanh(G)
        // c_t := f_t .* c_{t-1} + i_t .* g_t
        // h_t := o_t .* tanh[c_t]
        float* output_data = top_blob.row(ti);
        for (int q = 0; q < num_output; q++)
        {
            const float* gates_data = gates.row(q);

            float I = gates_data[0];
            float F = gates_data[1];
            float O = gates_data[2];
            float G = gates_data[3];

            I = 1.f / (1.f + exp(-I));
            F = 1.f / (1.f + exp(-F));
            O = 1.f / (1.f + exp(-O));
            G = tanh(G);

            float cell2 = F * cell_state[q] + I * G;
            float H = O * tanh(cell2);
            cell_state[q] = cell2;
            hidden_state[q] = H;
            output_data[q] = H;
        }
    }

    return 0;
}

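// single-blob forward: hidden and cell state start at zero
// output is T rows of num_output * num_directions (forward result first,
// reverse result appended along w when direction == 2)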
int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    int T = bottom_blob.h;

    int num_directions = direction == 2 ? 2 : 1;

    // initial hidden state
    Mat hidden(num_output, 4u, opt.workspace_allocator);
    if (hidden.empty())
        return -100;
    hidden.fill(0.f);

    Mat cell(num_output, 4u, opt.workspace_allocator);
    if (cell.empty())
        return -100;
    cell.fill(0.f);

    top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret != 0)
            return ret;
    }

    if (direction == 2)
    {
        Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_forward.empty())
            return -100;

        Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator);
        if (top_blob_reverse.empty())
            return -100;

        int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, cell, opt);
        if (ret0 != 0)
            return ret0;

        hidden.fill(0.0f);
        cell.fill(0.0f);

        int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, cell, opt);
        if (ret1 != 0)
            return ret1;

        // concat the forward and reverse results along w: [ forward | reverse ] per row
        for (int i = 0; i < T; i++)
        {
            const float* pf = top_blob_forward.row(i);
            const float* pr = top_blob_reverse.row(i);
            float* ptr = top_blob.row(i);

            memcpy(ptr, pf, num_output * sizeof(float));
            memcpy(ptr + num_output, pr, num_output * sizeof(float));
        }
    }

    return 0;
}

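// three-blob forward: bottom_blobs = { input, initial hidden state, initial cell state },
// top_blobs = { output, final hidden state, final cell state }
// falls back to the single-blob forward when the blob counts do not match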
int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
    if (bottom_blobs.size() != 3 || top_blobs.size() != 3)
    {
        return forward(bottom_blobs[0], top_blobs[0], opt);
    }
    const Mat& bottom_blob = bottom_blobs[0];
    int T = bottom_blob.h;
    Mat& top_blob = top_blobs[0];
    Mat& hidden_state = top_blobs[1];
    Mat& cell_state = top_blobs[2];

    // copy previous states
    hidden_state = bottom_blobs[1].clone(opt.blob_allocator);
    cell_state = bottom_blobs[2].clone(opt.blob_allocator);

    top_blob.create(num_output, T, 4u, opt.blob_allocator);
    if (top_blob.empty())
        return -100;

    // Uni directional
    if (direction == 0 || direction == 1)
    {
        int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden_state, cell_state, opt);
        if (ret != 0)
            return ret;
    }

    return 0;
}

} // namespace ncnn