1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "paramdict.h"
16 
17 #include "datareader.h"
18 #include "mat.h"
19 #include "platform.h"
20 
21 #include <ctype.h>
22 
23 #if NCNN_STDIO
24 #include <stdio.h>
25 #endif
26 
27 namespace ncnn {
28 
29 class ParamDictPrivate
30 {
31 public:
32     struct
33     {
34         // 0 = null
35         // 1 = int/float
36         // 2 = int
37         // 3 = float
38         // 4 = array of int/float
39         // 5 = array of int
40         // 6 = array of float
41         int type;
42         union
43         {
44             int i;
45             float f;
46         };
47         Mat v;
48     } params[NCNN_MAX_PARAM_COUNT];
49 };
50 
ParamDict()51 ParamDict::ParamDict()
52     : d(new ParamDictPrivate)
53 {
54     clear();
55 }
56 
~ParamDict()57 ParamDict::~ParamDict()
58 {
59     delete d;
60 }
61 
ParamDict(const ParamDict & rhs)62 ParamDict::ParamDict(const ParamDict& rhs)
63     : d(new ParamDictPrivate)
64 {
65     for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
66     {
67         int type = rhs.d->params[i].type;
68         d->params[i].type = type;
69         if (type == 1 || type == 2 || type == 3)
70         {
71             d->params[i].i = rhs.d->params[i].i;
72         }
73         else // if (type == 4 || type == 5 || type == 6)
74         {
75             d->params[i].v = rhs.d->params[i].v;
76         }
77     }
78 }
79 
operator =(const ParamDict & rhs)80 ParamDict& ParamDict::operator=(const ParamDict& rhs)
81 {
82     if (this == &rhs)
83         return *this;
84 
85     for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
86     {
87         int type = rhs.d->params[i].type;
88         d->params[i].type = type;
89         if (type == 1 || type == 2 || type == 3)
90         {
91             d->params[i].i = rhs.d->params[i].i;
92         }
93         else // if (type == 4 || type == 5 || type == 6)
94         {
95             d->params[i].v = rhs.d->params[i].v;
96         }
97     }
98 
99     return *this;
100 }
101 
type(int id) const102 int ParamDict::type(int id) const
103 {
104     return d->params[id].type;
105 }
106 
107 // TODO strict type check
get(int id,int def) const108 int ParamDict::get(int id, int def) const
109 {
110     return d->params[id].type ? d->params[id].i : def;
111 }
112 
get(int id,float def) const113 float ParamDict::get(int id, float def) const
114 {
115     return d->params[id].type ? d->params[id].f : def;
116 }
117 
get(int id,const Mat & def) const118 Mat ParamDict::get(int id, const Mat& def) const
119 {
120     return d->params[id].type ? d->params[id].v : def;
121 }
122 
set(int id,int i)123 void ParamDict::set(int id, int i)
124 {
125     d->params[id].type = 2;
126     d->params[id].i = i;
127 }
128 
set(int id,float f)129 void ParamDict::set(int id, float f)
130 {
131     d->params[id].type = 3;
132     d->params[id].f = f;
133 }
134 
set(int id,const Mat & v)135 void ParamDict::set(int id, const Mat& v)
136 {
137     d->params[id].type = 4;
138     d->params[id].v = v;
139 }
140 
clear()141 void ParamDict::clear()
142 {
143     for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
144     {
145         d->params[i].type = 0;
146         d->params[i].v = Mat();
147     }
148 }
149 
150 #if NCNN_STRING
vstr_is_float(const char vstr[16])151 static bool vstr_is_float(const char vstr[16])
152 {
153     // look ahead for determine isfloat
154     for (int j = 0; j < 16; j++)
155     {
156         if (vstr[j] == '\0')
157             break;
158 
159         if (vstr[j] == '.' || tolower(vstr[j]) == 'e')
160             return true;
161     }
162 
163     return false;
164 }
165 
vstr_to_float(const char vstr[16])166 static float vstr_to_float(const char vstr[16])
167 {
168     double v = 0.0;
169 
170     const char* p = vstr;
171 
172     // sign
173     bool sign = *p != '-';
174     if (*p == '+' || *p == '-')
175     {
176         p++;
177     }
178 
179     // digits before decimal point or exponent
180     unsigned int v1 = 0;
181     while (isdigit(*p))
182     {
183         v1 = v1 * 10 + (*p - '0');
184         p++;
185     }
186 
187     v = (double)v1;
188 
189     // digits after decimal point
190     if (*p == '.')
191     {
192         p++;
193 
194         unsigned int pow10 = 1;
195         unsigned int v2 = 0;
196 
197         while (isdigit(*p))
198         {
199             v2 = v2 * 10 + (*p - '0');
200             pow10 *= 10;
201             p++;
202         }
203 
204         v += v2 / (double)pow10;
205     }
206 
207     // exponent
208     if (*p == 'e' || *p == 'E')
209     {
210         p++;
211 
212         // sign of exponent
213         bool fact = *p != '-';
214         if (*p == '+' || *p == '-')
215         {
216             p++;
217         }
218 
219         // digits of exponent
220         unsigned int expon = 0;
221         while (isdigit(*p))
222         {
223             expon = expon * 10 + (*p - '0');
224             p++;
225         }
226 
227         double scale = 1.0;
228         while (expon >= 8)
229         {
230             scale *= 1e8;
231             expon -= 8;
232         }
233         while (expon > 0)
234         {
235             scale *= 10.0;
236             expon -= 1;
237         }
238 
239         v = fact ? v * scale : v / scale;
240     }
241 
242     //     fprintf(stderr, "v = %f\n", v);
243     return sign ? (float)v : (float)-v;
244 }
245 
load_param(const DataReader & dr)246 int ParamDict::load_param(const DataReader& dr)
247 {
248     clear();
249 
250     //     0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
251 
252     // parse each key=value pair
253     int id = 0;
254     while (dr.scan("%d=", &id) == 1)
255     {
256         bool is_array = id <= -23300;
257         if (is_array)
258         {
259             id = -id - 23300;
260         }
261 
262         if (id >= NCNN_MAX_PARAM_COUNT)
263         {
264             NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
265             return -1;
266         }
267 
268         if (is_array)
269         {
270             int len = 0;
271             int nscan = dr.scan("%d", &len);
272             if (nscan != 1)
273             {
274                 NCNN_LOGE("ParamDict read array length failed");
275                 return -1;
276             }
277 
278             d->params[id].v.create(len);
279 
280             for (int j = 0; j < len; j++)
281             {
282                 char vstr[16];
283                 nscan = dr.scan(",%15[^,\n ]", vstr);
284                 if (nscan != 1)
285                 {
286                     NCNN_LOGE("ParamDict read array element failed");
287                     return -1;
288                 }
289 
290                 bool is_float = vstr_is_float(vstr);
291 
292                 if (is_float)
293                 {
294                     float* ptr = d->params[id].v;
295                     ptr[j] = vstr_to_float(vstr);
296                 }
297                 else
298                 {
299                     int* ptr = d->params[id].v;
300                     nscan = sscanf(vstr, "%d", &ptr[j]);
301                     if (nscan != 1)
302                     {
303                         NCNN_LOGE("ParamDict parse array element failed");
304                         return -1;
305                     }
306                 }
307 
308                 d->params[id].type = is_float ? 6 : 5;
309             }
310         }
311         else
312         {
313             char vstr[16];
314             int nscan = dr.scan("%15s", vstr);
315             if (nscan != 1)
316             {
317                 NCNN_LOGE("ParamDict read value failed");
318                 return -1;
319             }
320 
321             bool is_float = vstr_is_float(vstr);
322 
323             if (is_float)
324             {
325                 d->params[id].f = vstr_to_float(vstr);
326             }
327             else
328             {
329                 nscan = sscanf(vstr, "%d", &d->params[id].i);
330                 if (nscan != 1)
331                 {
332                     NCNN_LOGE("ParamDict parse value failed");
333                     return -1;
334                 }
335             }
336 
337             d->params[id].type = is_float ? 3 : 2;
338         }
339     }
340 
341     return 0;
342 }
343 #endif // NCNN_STRING
344 
load_param_bin(const DataReader & dr)345 int ParamDict::load_param_bin(const DataReader& dr)
346 {
347     clear();
348 
349     //     binary 0
350     //     binary 100
351     //     binary 1
352     //     binary 1.250000
353     //     binary 3 | array_bit
354     //     binary 5
355     //     binary 0.1
356     //     binary 0.2
357     //     binary 0.4
358     //     binary 0.8
359     //     binary 1.0
360     //     binary -233(EOP)
361 
362     int id = 0;
363     size_t nread;
364     nread = dr.read(&id, sizeof(int));
365     if (nread != sizeof(int))
366     {
367         NCNN_LOGE("ParamDict read id failed %zd", nread);
368         return -1;
369     }
370 
371     while (id != -233)
372     {
373         bool is_array = id <= -23300;
374         if (is_array)
375         {
376             id = -id - 23300;
377         }
378 
379         if (id >= NCNN_MAX_PARAM_COUNT)
380         {
381             NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
382             return -1;
383         }
384 
385         if (is_array)
386         {
387             int len = 0;
388             nread = dr.read(&len, sizeof(int));
389             if (nread != sizeof(int))
390             {
391                 NCNN_LOGE("ParamDict read array length failed %zd", nread);
392                 return -1;
393             }
394 
395             d->params[id].v.create(len);
396 
397             float* ptr = d->params[id].v;
398             nread = dr.read(ptr, sizeof(float) * len);
399             if (nread != sizeof(float) * len)
400             {
401                 NCNN_LOGE("ParamDict read array element failed %zd", nread);
402                 return -1;
403             }
404 
405             d->params[id].type = 4;
406         }
407         else
408         {
409             nread = dr.read(&d->params[id].f, sizeof(float));
410             if (nread != sizeof(float))
411             {
412                 NCNN_LOGE("ParamDict read value failed %zd", nread);
413                 return -1;
414             }
415 
416             d->params[id].type = 1;
417         }
418 
419         nread = dr.read(&id, sizeof(int));
420         if (nread != sizeof(int))
421         {
422             NCNN_LOGE("ParamDict read EOP failed %zd", nread);
423             return -1;
424         }
425     }
426 
427     return 0;
428 }
429 
430 } // namespace ncnn
431