1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "paramdict.h"
16
17 #include "datareader.h"
18 #include "mat.h"
19 #include "platform.h"
20
21 #include <ctype.h>
22
23 #if NCNN_STDIO
24 #include <stdio.h>
25 #endif
26
27 namespace ncnn {
28
// Private storage behind ParamDict: a fixed table of NCNN_MAX_PARAM_COUNT
// slots that the ParamDict members index directly by parameter id.
class ParamDictPrivate
{
public:
    struct
    {
        // Tag describing what the slot currently holds:
        // 0 = null
        // 1 = int/float
        // 2 = int
        // 3 = float
        // 4 = array of int/float
        // 5 = array of int
        // 6 = array of float
        int type;
        // Scalar payload; i and f share the same storage.
        union
        {
            int i;
            float f;
        };
        // Array payload.
        Mat v;
    } params[NCNN_MAX_PARAM_COUNT];
};
50
ParamDict()51 ParamDict::ParamDict()
52 : d(new ParamDictPrivate)
53 {
54 clear();
55 }
56
~ParamDict()57 ParamDict::~ParamDict()
58 {
59 delete d;
60 }
61
// Copy-construct: allocate a fresh slot table and duplicate every slot of rhs.
ParamDict::ParamDict(const ParamDict& rhs)
    : d(new ParamDictPrivate)
{
    for (int slot = 0; slot < NCNN_MAX_PARAM_COUNT; slot++)
    {
        const int tag = rhs.d->params[slot].type;
        d->params[slot].type = tag;

        switch (tag)
        {
        case 1:
        case 2:
        case 3:
            // scalar tags: the int member is copied whether the slot was
            // written as int or float, since both share the same storage
            d->params[slot].i = rhs.d->params[slot].i;
            break;
        default:
            // null and array tags (0, 4, 5, 6): copy the Mat
            d->params[slot].v = rhs.d->params[slot].v;
            break;
        }
    }
}
79
// Copy-assign every slot of rhs into this dictionary.
//
// Self-assignment is a no-op. For scalar slots (types 1, 2, 3) the shared
// int/float storage is copied and any array previously held by the
// destination slot is released, so the destination never keeps stale array
// data from before the assignment. For all other types (null and arrays) the
// Mat is copied.
//
// Returns a reference to *this.
ParamDict& ParamDict::operator=(const ParamDict& rhs)
{
    if (this == &rhs)
        return *this;

    for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
    {
        const int type = rhs.d->params[i].type;
        d->params[i].type = type;
        if (type == 1 || type == 2 || type == 3)
        {
            d->params[i].i = rhs.d->params[i].i;

            // drop whatever array this slot held before the assignment;
            // otherwise get(id, Mat) would still hand out the old data
            d->params[i].v = Mat();
        }
        else // type == 0 || type == 4 || type == 5 || type == 6
        {
            d->params[i].v = rhs.d->params[i].v;
        }
    }

    return *this;
}
101
type(int id) const102 int ParamDict::type(int id) const
103 {
104 return d->params[id].type;
105 }
106
107 // TODO strict type check
get(int id,int def) const108 int ParamDict::get(int id, int def) const
109 {
110 return d->params[id].type ? d->params[id].i : def;
111 }
112
get(int id,float def) const113 float ParamDict::get(int id, float def) const
114 {
115 return d->params[id].type ? d->params[id].f : def;
116 }
117
// Return the array stored in slot id, or def when the slot is null.
// Any non-null type tag yields the slot's Mat; the tag is not checked
// against the requested type.
Mat ParamDict::get(int id, const Mat& def) const
{
    if (d->params[id].type == 0)
        return def;

    return d->params[id].v;
}
122
set(int id,int i)123 void ParamDict::set(int id, int i)
124 {
125 d->params[id].type = 2;
126 d->params[id].i = i;
127 }
128
set(int id,float f)129 void ParamDict::set(int id, float f)
130 {
131 d->params[id].type = 3;
132 d->params[id].f = f;
133 }
134
set(int id,const Mat & v)135 void ParamDict::set(int id, const Mat& v)
136 {
137 d->params[id].type = 4;
138 d->params[id].v = v;
139 }
140
clear()141 void ParamDict::clear()
142 {
143 for (int i = 0; i < NCNN_MAX_PARAM_COUNT; i++)
144 {
145 d->params[i].type = 0;
146 d->params[i].v = Mat();
147 }
148 }
149
150 #if NCNN_STRING
// Decide whether a value token should be parsed as a float.
//
// Scans at most 16 characters, stopping at the terminating NUL, and returns
// true as soon as a '.' or an 'e'/'E' is seen; returns false otherwise.
static bool vstr_is_float(const char vstr[16])
{
    for (int j = 0; j < 16 && vstr[j] != '\0'; j++)
    {
        // tolower() requires a value representable as unsigned char; passing
        // a negative plain char is undefined behaviour
        const unsigned char c = static_cast<unsigned char>(vstr[j]);

        if (c == '.' || tolower(c) == 'e')
            return true;
    }

    return false;
}
165
// Parse a decimal float token of the form [+-]digits[.digits][(e|E)[+-]digits].
//
// Parsing stops at the first character that does not fit the grammar; the
// remainder of the string is ignored. Digits are accumulated in double so
// that long integer or fractional parts cannot wrap around, and the exponent
// saturates instead of overflowing.
//
// Returns the parsed value converted to float.
static float vstr_to_float(const char vstr[16])
{
    const char* p = vstr;

    // sign
    const bool negative = *p == '-';
    if (*p == '+' || *p == '-')
    {
        p++;
    }

    // digits before decimal point or exponent
    // (isdigit() needs an unsigned char value; a negative plain char is UB)
    double v = 0.0;
    while (isdigit(static_cast<unsigned char>(*p)))
    {
        v = v * 10.0 + (*p - '0');
        p++;
    }

    // digits after decimal point
    if (*p == '.')
    {
        p++;

        double frac = 0.0;
        double pow10 = 1.0;
        while (isdigit(static_cast<unsigned char>(*p)))
        {
            frac = frac * 10.0 + (*p - '0');
            pow10 *= 10.0;
            p++;
        }

        v += frac / pow10;
    }

    // exponent
    if (*p == 'e' || *p == 'E')
    {
        p++;

        // sign of exponent
        const bool exp_negative = *p == '-';
        if (*p == '+' || *p == '-')
        {
            p++;
        }

        // digits of exponent; stop accumulating once the value reaches 1000,
        // which is already far beyond the range of double, so the counter
        // cannot wrap and the scaling loops below stay short
        unsigned int expon = 0;
        while (isdigit(static_cast<unsigned char>(*p)))
        {
            if (expon < 1000)
            {
                expon = expon * 10 + (*p - '0');
            }
            p++;
        }

        double scale = 1.0;
        while (expon >= 8)
        {
            scale *= 1e8;
            expon -= 8;
        }
        while (expon > 0)
        {
            scale *= 10.0;
            expon -= 1;
        }

        v = exp_negative ? v / scale : v * scale;
    }

    return negative ? static_cast<float>(-v) : static_cast<float>(v);
}
245
load_param(const DataReader & dr)246 int ParamDict::load_param(const DataReader& dr)
247 {
248 clear();
249
250 // 0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0
251
252 // parse each key=value pair
253 int id = 0;
254 while (dr.scan("%d=", &id) == 1)
255 {
256 bool is_array = id <= -23300;
257 if (is_array)
258 {
259 id = -id - 23300;
260 }
261
262 if (id >= NCNN_MAX_PARAM_COUNT)
263 {
264 NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
265 return -1;
266 }
267
268 if (is_array)
269 {
270 int len = 0;
271 int nscan = dr.scan("%d", &len);
272 if (nscan != 1)
273 {
274 NCNN_LOGE("ParamDict read array length failed");
275 return -1;
276 }
277
278 d->params[id].v.create(len);
279
280 for (int j = 0; j < len; j++)
281 {
282 char vstr[16];
283 nscan = dr.scan(",%15[^,\n ]", vstr);
284 if (nscan != 1)
285 {
286 NCNN_LOGE("ParamDict read array element failed");
287 return -1;
288 }
289
290 bool is_float = vstr_is_float(vstr);
291
292 if (is_float)
293 {
294 float* ptr = d->params[id].v;
295 ptr[j] = vstr_to_float(vstr);
296 }
297 else
298 {
299 int* ptr = d->params[id].v;
300 nscan = sscanf(vstr, "%d", &ptr[j]);
301 if (nscan != 1)
302 {
303 NCNN_LOGE("ParamDict parse array element failed");
304 return -1;
305 }
306 }
307
308 d->params[id].type = is_float ? 6 : 5;
309 }
310 }
311 else
312 {
313 char vstr[16];
314 int nscan = dr.scan("%15s", vstr);
315 if (nscan != 1)
316 {
317 NCNN_LOGE("ParamDict read value failed");
318 return -1;
319 }
320
321 bool is_float = vstr_is_float(vstr);
322
323 if (is_float)
324 {
325 d->params[id].f = vstr_to_float(vstr);
326 }
327 else
328 {
329 nscan = sscanf(vstr, "%d", &d->params[id].i);
330 if (nscan != 1)
331 {
332 NCNN_LOGE("ParamDict parse value failed");
333 return -1;
334 }
335 }
336
337 d->params[id].type = is_float ? 3 : 2;
338 }
339 }
340
341 return 0;
342 }
343 #endif // NCNN_STRING
344
load_param_bin(const DataReader & dr)345 int ParamDict::load_param_bin(const DataReader& dr)
346 {
347 clear();
348
349 // binary 0
350 // binary 100
351 // binary 1
352 // binary 1.250000
353 // binary 3 | array_bit
354 // binary 5
355 // binary 0.1
356 // binary 0.2
357 // binary 0.4
358 // binary 0.8
359 // binary 1.0
360 // binary -233(EOP)
361
362 int id = 0;
363 size_t nread;
364 nread = dr.read(&id, sizeof(int));
365 if (nread != sizeof(int))
366 {
367 NCNN_LOGE("ParamDict read id failed %zd", nread);
368 return -1;
369 }
370
371 while (id != -233)
372 {
373 bool is_array = id <= -23300;
374 if (is_array)
375 {
376 id = -id - 23300;
377 }
378
379 if (id >= NCNN_MAX_PARAM_COUNT)
380 {
381 NCNN_LOGE("id < NCNN_MAX_PARAM_COUNT failed (id=%d, NCNN_MAX_PARAM_COUNT=%d)", id, NCNN_MAX_PARAM_COUNT);
382 return -1;
383 }
384
385 if (is_array)
386 {
387 int len = 0;
388 nread = dr.read(&len, sizeof(int));
389 if (nread != sizeof(int))
390 {
391 NCNN_LOGE("ParamDict read array length failed %zd", nread);
392 return -1;
393 }
394
395 d->params[id].v.create(len);
396
397 float* ptr = d->params[id].v;
398 nread = dr.read(ptr, sizeof(float) * len);
399 if (nread != sizeof(float) * len)
400 {
401 NCNN_LOGE("ParamDict read array element failed %zd", nread);
402 return -1;
403 }
404
405 d->params[id].type = 4;
406 }
407 else
408 {
409 nread = dr.read(&d->params[id].f, sizeof(float));
410 if (nread != sizeof(float))
411 {
412 NCNN_LOGE("ParamDict read value failed %zd", nread);
413 return -1;
414 }
415
416 d->params[id].type = 1;
417 }
418
419 nread = dr.read(&id, sizeof(int));
420 if (nread != sizeof(int))
421 {
422 NCNN_LOGE("ParamDict read EOP failed %zd", nread);
423 return -1;
424 }
425 }
426
427 return 0;
428 }
429
430 } // namespace ncnn
431