// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #include "cast.h"
16
17 namespace ncnn {
18
Cast()19 Cast::Cast()
20 {
21 one_blob_only = true;
22 support_inplace = false;
23 support_packing = true;
24 }
25
load_param(const ParamDict & pd)26 int Cast::load_param(const ParamDict& pd)
27 {
28 type_from = pd.get(0, 0);
29 type_to = pd.get(1, 0);
30
31 return 0;
32 }
33
// Convert a float to int8 by rounding to the nearest integer
// (halfway cases away from zero) and saturating to [-128, 127].
signed char float32_to_int8(float value)
{
    const float rounded = (value >= 0.f) ? value + 0.5f : value - 0.5f;

    if (rounded > 127)
        return 127;
    if (rounded < -128)
        return -128;

    return static_cast<signed char>(rounded);
}
50
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const51 int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
52 {
53 if (type_from == type_to)
54 {
55 top_blob = bottom_blob;
56 return 0;
57 }
58
59 int w = bottom_blob.w;
60 int h = bottom_blob.h;
61 int channels = bottom_blob.c;
62 int dims = bottom_blob.dims;
63 size_t elemsize = bottom_blob.elemsize;
64 int elempack = bottom_blob.elempack;
65
66 size_t out_elemsize = elemsize;
67 if (type_to == 1)
68 {
69 // float32
70 out_elemsize = 4 * elempack;
71 }
72 else if (type_to == 2)
73 {
74 // float16
75 out_elemsize = 2 * elempack;
76 }
77 else if (type_to == 3)
78 {
79 // int8
80 out_elemsize = elempack;
81 }
82 else if (type_to == 4)
83 {
84 // bfloat16
85 out_elemsize = 2 * elempack;
86 }
87
88 if (dims == 1)
89 {
90 top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
91 }
92 else if (dims == 2)
93 {
94 top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
95 }
96 else if (dims == 3)
97 {
98 top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
99 }
100 if (top_blob.empty())
101 return -100;
102
103 int size = w * h * elempack;
104
105 if (type_from == 1 && type_to == 2)
106 {
107 #pragma omp parallel for num_threads(opt.num_threads)
108 for (int q = 0; q < channels; q++)
109 {
110 const float* ptr = bottom_blob.channel(q);
111 unsigned short* outptr = top_blob.channel(q);
112
113 for (int i = 0; i < size; i++)
114 {
115 outptr[i] = float32_to_float16(ptr[i]);
116 }
117 }
118 }
119
120 if (type_from == 2 && type_to == 1)
121 {
122 #pragma omp parallel for num_threads(opt.num_threads)
123 for (int q = 0; q < channels; q++)
124 {
125 const unsigned short* ptr = bottom_blob.channel(q);
126 float* outptr = top_blob.channel(q);
127
128 for (int i = 0; i < size; i++)
129 {
130 outptr[i] = float16_to_float32(ptr[i]);
131 }
132 }
133 }
134
135 if (type_from == 3 && type_to == 1)
136 {
137 #pragma omp parallel for num_threads(opt.num_threads)
138 for (int q = 0; q < channels; q++)
139 {
140 const signed char* ptr = bottom_blob.channel(q);
141 float* outptr = top_blob.channel(q);
142
143 for (int i = 0; i < size; i++)
144 {
145 outptr[i] = (float)ptr[i];
146 }
147 }
148 }
149
150 if (type_from == 1 && type_to == 4)
151 {
152 #pragma omp parallel for num_threads(opt.num_threads)
153 for (int q = 0; q < channels; q++)
154 {
155 const float* ptr = bottom_blob.channel(q);
156 unsigned short* outptr = top_blob.channel(q);
157
158 for (int i = 0; i < size; i++)
159 {
160 outptr[i] = float32_to_bfloat16(ptr[i]);
161 }
162 }
163 }
164
165 if (type_from == 4 && type_to == 1)
166 {
167 #pragma omp parallel for num_threads(opt.num_threads)
168 for (int q = 0; q < channels; q++)
169 {
170 const unsigned short* ptr = bottom_blob.channel(q);
171 float* outptr = top_blob.channel(q);
172
173 for (int i = 0; i < size; i++)
174 {
175 outptr[i] = bfloat16_to_float32(ptr[i]);
176 }
177 }
178 }
179
180 // TODO more cast type
181
182 return 0;
183 }
184
185 } // namespace ncnn
186