1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "modelbin.h"
16 
17 #include "datareader.h"
18 
19 #include <string.h>
20 
21 namespace ncnn {
22 
ModelBin()23 ModelBin::ModelBin()
24 {
25 }
26 
~ModelBin()27 ModelBin::~ModelBin()
28 {
29 }
30 
load(int w,int h,int type) const31 Mat ModelBin::load(int w, int h, int type) const
32 {
33     Mat m = load(w * h, type);
34     if (m.empty())
35         return m;
36 
37     return m.reshape(w, h);
38 }
39 
load(int w,int h,int c,int type) const40 Mat ModelBin::load(int w, int h, int c, int type) const
41 {
42     Mat m = load(w * h * c, type);
43     if (m.empty())
44         return m;
45 
46     return m.reshape(w, h, c);
47 }
48 
49 class ModelBinFromDataReaderPrivate
50 {
51 public:
ModelBinFromDataReaderPrivate(const DataReader & _dr)52     ModelBinFromDataReaderPrivate(const DataReader& _dr)
53         : dr(_dr)
54     {
55     }
56     const DataReader& dr;
57 };
58 
ModelBinFromDataReader(const DataReader & _dr)59 ModelBinFromDataReader::ModelBinFromDataReader(const DataReader& _dr)
60     : ModelBin(), d(new ModelBinFromDataReaderPrivate(_dr))
61 {
62 }
63 
~ModelBinFromDataReader()64 ModelBinFromDataReader::~ModelBinFromDataReader()
65 {
66     delete d;
67 }
68 
// Copy constructor: the source's private data is intentionally NOT shared,
// which would otherwise be deleted twice. The new object ends up with d == 0
// and cannot load anything.
// NOTE(review): presumably declared private in the header to forbid copying — confirm.
ModelBinFromDataReader::ModelBinFromDataReader(const ModelBinFromDataReader& /*other*/)
    : d(0)
{
}
73 
// Assignment does nothing: each object keeps its own private data, so two
// instances never end up owning the same pimpl.
ModelBinFromDataReader& ModelBinFromDataReader::operator=(const ModelBinFromDataReader& /*other*/)
{
    ModelBinFromDataReader& self = *this;
    return self;
}
78 
load(int w,int type) const79 Mat ModelBinFromDataReader::load(int w, int type) const
80 {
81     if (type == 0)
82     {
83         size_t nread;
84 
85         union
86         {
87             struct
88             {
89                 unsigned char f0;
90                 unsigned char f1;
91                 unsigned char f2;
92                 unsigned char f3;
93             };
94             unsigned int tag;
95         } flag_struct;
96 
97         nread = d->dr.read(&flag_struct, sizeof(flag_struct));
98         if (nread != sizeof(flag_struct))
99         {
100             NCNN_LOGE("ModelBin read flag_struct failed %zd", nread);
101             return Mat();
102         }
103 
104         unsigned int flag = flag_struct.f0 + flag_struct.f1 + flag_struct.f2 + flag_struct.f3;
105 
106         if (flag_struct.tag == 0x01306B47)
107         {
108             // half-precision data
109             size_t align_data_size = alignSize(w * sizeof(unsigned short), 4);
110             std::vector<unsigned short> float16_weights;
111             float16_weights.resize(align_data_size);
112             nread = d->dr.read(float16_weights.data(), align_data_size);
113             if (nread != align_data_size)
114             {
115                 NCNN_LOGE("ModelBin read float16_weights failed %zd", nread);
116                 return Mat();
117             }
118 
119             return Mat::from_float16(float16_weights.data(), w);
120         }
121         else if (flag_struct.tag == 0x000D4B38)
122         {
123             // int8 data
124             size_t align_data_size = alignSize(w, 4);
125             std::vector<signed char> int8_weights;
126             int8_weights.resize(align_data_size);
127             nread = d->dr.read(int8_weights.data(), align_data_size);
128             if (nread != align_data_size)
129             {
130                 NCNN_LOGE("ModelBin read int8_weights failed %zd", nread);
131                 return Mat();
132             }
133 
134             Mat m(w, (size_t)1u);
135             if (m.empty())
136                 return m;
137 
138             memcpy(m.data, int8_weights.data(), w);
139 
140             return m;
141         }
142         else if (flag_struct.tag == 0x0002C056)
143         {
144             Mat m(w);
145             if (m.empty())
146                 return m;
147 
148             // raw data with extra scaling
149             nread = d->dr.read(m, w * sizeof(float));
150             if (nread != w * sizeof(float))
151             {
152                 NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
153                 return Mat();
154             }
155 
156             return m;
157         }
158 
159         Mat m(w);
160         if (m.empty())
161             return m;
162 
163         if (flag != 0)
164         {
165             // quantized data
166             float quantization_value[256];
167             nread = d->dr.read(quantization_value, 256 * sizeof(float));
168             if (nread != 256 * sizeof(float))
169             {
170                 NCNN_LOGE("ModelBin read quantization_value failed %zd", nread);
171                 return Mat();
172             }
173 
174             size_t align_weight_data_size = alignSize(w * sizeof(unsigned char), 4);
175             std::vector<unsigned char> index_array;
176             index_array.resize(align_weight_data_size);
177             nread = d->dr.read(index_array.data(), align_weight_data_size);
178             if (nread != align_weight_data_size)
179             {
180                 NCNN_LOGE("ModelBin read index_array failed %zd", nread);
181                 return Mat();
182             }
183 
184             float* ptr = m;
185             for (int i = 0; i < w; i++)
186             {
187                 ptr[i] = quantization_value[index_array[i]];
188             }
189         }
190         else if (flag_struct.f0 == 0)
191         {
192             // raw data
193             nread = d->dr.read(m, w * sizeof(float));
194             if (nread != w * sizeof(float))
195             {
196                 NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
197                 return Mat();
198             }
199         }
200 
201         return m;
202     }
203     else if (type == 1)
204     {
205         Mat m(w);
206         if (m.empty())
207             return m;
208 
209         // raw data
210         size_t nread = d->dr.read(m, w * sizeof(float));
211         if (nread != w * sizeof(float))
212         {
213             NCNN_LOGE("ModelBin read weight_data failed %zd", nread);
214             return Mat();
215         }
216 
217         return m;
218     }
219     else
220     {
221         NCNN_LOGE("ModelBin load type %d not implemented", type);
222         return Mat();
223     }
224 
225     return Mat();
226 }
227 
228 class ModelBinFromMatArrayPrivate
229 {
230 public:
ModelBinFromMatArrayPrivate(const Mat * _weights)231     ModelBinFromMatArrayPrivate(const Mat* _weights)
232         : weights(_weights)
233     {
234     }
235     mutable const Mat* weights;
236 };
237 
ModelBinFromMatArray(const Mat * _weights)238 ModelBinFromMatArray::ModelBinFromMatArray(const Mat* _weights)
239     : ModelBin(), d(new ModelBinFromMatArrayPrivate(_weights))
240 {
241 }
242 
~ModelBinFromMatArray()243 ModelBinFromMatArray::~ModelBinFromMatArray()
244 {
245     delete d;
246 }
247 
// Copy constructor: the source's private data is intentionally NOT shared,
// to avoid a double delete. The copy has d == 0 and cannot load anything.
// NOTE(review): presumably declared private in the header to forbid copying — confirm.
ModelBinFromMatArray::ModelBinFromMatArray(const ModelBinFromMatArray& /*other*/)
    : d(0)
{
}
252 
// Assignment does nothing: each object keeps its own cursor, so two
// instances never own the same private data.
ModelBinFromMatArray& ModelBinFromMatArray::operator=(const ModelBinFromMatArray& /*other*/)
{
    ModelBinFromMatArray& self = *this;
    return self;
}
257 
load(int,int) const258 Mat ModelBinFromMatArray::load(int /*w*/, int /*type*/) const
259 {
260     if (!d->weights)
261         return Mat();
262 
263     Mat m = d->weights[0];
264     d->weights++;
265     return m;
266 }
267 
268 } // namespace ncnn
269