// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

15 #include "cast.h"
16 
17 namespace ncnn {
18 
Cast()19 Cast::Cast()
20 {
21     one_blob_only = true;
22     support_inplace = false;
23     support_packing = true;
24 }
25 
load_param(const ParamDict & pd)26 int Cast::load_param(const ParamDict& pd)
27 {
28     type_from = pd.get(0, 0);
29     type_to = pd.get(1, 0);
30 
31     return 0;
32 }
33 
34 // round to nearest
float32_to_int8(float value)35 signed char float32_to_int8(float value)
36 {
37     float tmp;
38     if (value >= 0.f)
39         tmp = value + 0.5f;
40     else
41         tmp = value - 0.5f;
42 
43     if (tmp > 127)
44         return 127;
45     if (tmp < -128)
46         return -128;
47 
48     return static_cast<signed char>(tmp);
49 }
50 
forward(const Mat & bottom_blob,Mat & top_blob,const Option & opt) const51 int Cast::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
52 {
53     if (type_from == type_to)
54     {
55         top_blob = bottom_blob;
56         return 0;
57     }
58 
59     int w = bottom_blob.w;
60     int h = bottom_blob.h;
61     int channels = bottom_blob.c;
62     int dims = bottom_blob.dims;
63     size_t elemsize = bottom_blob.elemsize;
64     int elempack = bottom_blob.elempack;
65 
66     size_t out_elemsize = elemsize;
67     if (type_to == 1)
68     {
69         // float32
70         out_elemsize = 4 * elempack;
71     }
72     else if (type_to == 2)
73     {
74         // float16
75         out_elemsize = 2 * elempack;
76     }
77     else if (type_to == 3)
78     {
79         // int8
80         out_elemsize = elempack;
81     }
82     else if (type_to == 4)
83     {
84         // bfloat16
85         out_elemsize = 2 * elempack;
86     }
87 
88     if (dims == 1)
89     {
90         top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
91     }
92     else if (dims == 2)
93     {
94         top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
95     }
96     else if (dims == 3)
97     {
98         top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
99     }
100     if (top_blob.empty())
101         return -100;
102 
103     int size = w * h * elempack;
104 
105     if (type_from == 1 && type_to == 2)
106     {
107         #pragma omp parallel for num_threads(opt.num_threads)
108         for (int q = 0; q < channels; q++)
109         {
110             const float* ptr = bottom_blob.channel(q);
111             unsigned short* outptr = top_blob.channel(q);
112 
113             for (int i = 0; i < size; i++)
114             {
115                 outptr[i] = float32_to_float16(ptr[i]);
116             }
117         }
118     }
119 
120     if (type_from == 2 && type_to == 1)
121     {
122         #pragma omp parallel for num_threads(opt.num_threads)
123         for (int q = 0; q < channels; q++)
124         {
125             const unsigned short* ptr = bottom_blob.channel(q);
126             float* outptr = top_blob.channel(q);
127 
128             for (int i = 0; i < size; i++)
129             {
130                 outptr[i] = float16_to_float32(ptr[i]);
131             }
132         }
133     }
134 
135     if (type_from == 3 && type_to == 1)
136     {
137         #pragma omp parallel for num_threads(opt.num_threads)
138         for (int q = 0; q < channels; q++)
139         {
140             const signed char* ptr = bottom_blob.channel(q);
141             float* outptr = top_blob.channel(q);
142 
143             for (int i = 0; i < size; i++)
144             {
145                 outptr[i] = (float)ptr[i];
146             }
147         }
148     }
149 
150     if (type_from == 1 && type_to == 4)
151     {
152         #pragma omp parallel for num_threads(opt.num_threads)
153         for (int q = 0; q < channels; q++)
154         {
155             const float* ptr = bottom_blob.channel(q);
156             unsigned short* outptr = top_blob.channel(q);
157 
158             for (int i = 0; i < size; i++)
159             {
160                 outptr[i] = float32_to_bfloat16(ptr[i]);
161             }
162         }
163     }
164 
165     if (type_from == 4 && type_to == 1)
166     {
167         #pragma omp parallel for num_threads(opt.num_threads)
168         for (int q = 0; q < channels; q++)
169         {
170             const unsigned short* ptr = bottom_blob.channel(q);
171             float* outptr = top_blob.channel(q);
172 
173             for (int i = 0; i < size; i++)
174             {
175                 outptr[i] = bfloat16_to_float32(ptr[i]);
176             }
177         }
178     }
179 
180     // TODO more cast type
181 
182     return 0;
183 }
184 
185 } // namespace ncnn
186