1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 BUG1989. All rights reserved.
4 // Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
5 //
6 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
7 // in compliance with the License. You may obtain a copy of the License at
8 //
9 // https://opensource.org/licenses/BSD-3-Clause
10 //
11 // Unless required by applicable law or agreed to in writing, software distributed
12 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
13 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
14 // specific language governing permissions and limitations under the License.
15 
16 #include "dequantize_arm.h"
17 
18 namespace ncnn {
19 
forward_inplace(Mat & bottom_top_blob,const Option & opt) const20 int Dequantize_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
21 {
22     int dims = bottom_top_blob.dims;
23 
24     if (dims == 1)
25     {
26         int w = bottom_top_blob.w;
27 
28         int* intptr = bottom_top_blob;
29         float* ptr = bottom_top_blob;
30 
31         if (bias_term)
32         {
33             #pragma omp parallel for num_threads(opt.num_threads)
34             for (int i = 0; i < w; i++)
35             {
36                 ptr[i] = intptr[i] * scale + bias_data[i];
37             }
38         }
39         else
40         {
41             #pragma omp parallel for num_threads(opt.num_threads)
42             for (int i = 0; i < w; i++)
43             {
44                 ptr[i] = intptr[i] * scale;
45             }
46         }
47     }
48 
49     if (dims == 2)
50     {
51         int w = bottom_top_blob.w;
52         int h = bottom_top_blob.h;
53 
54         if (bias_term)
55         {
56             #pragma omp parallel for num_threads(opt.num_threads)
57             for (int i = 0; i < h; i++)
58             {
59                 const int* intptr = bottom_top_blob.row<const int>(i);
60                 float* ptr = bottom_top_blob.row(i);
61 
62                 float bias = bias_data_size > 1 ? bias_data[i] : bias_data[0];
63 
64                 for (int j = 0; j < w; j++)
65                 {
66                     ptr[j] = intptr[j] * scale + bias;
67                 }
68             }
69         }
70         else
71         {
72             #pragma omp parallel for num_threads(opt.num_threads)
73             for (int i = 0; i < h; i++)
74             {
75                 const int* intptr = bottom_top_blob.row<const int>(i);
76                 float* ptr = bottom_top_blob.row(i);
77 
78                 for (int j = 0; j < w; j++)
79                 {
80                     ptr[j] = intptr[j] * scale;
81                 }
82             }
83         }
84     }
85 
86     if (dims == 3)
87     {
88         int w = bottom_top_blob.w;
89         int h = bottom_top_blob.h;
90         int channels = bottom_top_blob.c;
91         int size = w * h;
92 
93         if (bias_term)
94         {
95             #pragma omp parallel for num_threads(opt.num_threads)
96             for (int q = 0; q < channels; q++)
97             {
98                 int* intptr = bottom_top_blob.channel(q);
99                 float* ptr = bottom_top_blob.channel(q);
100 
101                 float bias = bias_data[q];
102 
103 #if __ARM_NEON
104                 int nn = size >> 3;
105                 int remain = size & 7;
106 #else
107                 int remain = size;
108 #endif // __ARM_NEON
109 
110 #if __ARM_NEON
111 #if __aarch64__
112                 if (nn > 0)
113                 {
114                     asm volatile(
115                         "dup    v2.4s, %w6                   \n" // scale
116                         "dup    v3.4s, %w7                   \n" // bias
117                         "0:                                  \n"
118                         "prfm   pldl1keep, [%1, #128]        \n"
119                         "ld1    {v0.4s, v1.4s}, [%1], #32    \n" // data
120                         // top_s32 -> top_f32
121                         "scvtf  v5.4s, v0.4s                 \n"
122                         "scvtf  v6.4s, v1.4s                 \n"
123                         // top_f32 = top_f32 * scale_out
124                         "fmul   v5.4s, v5.4s, v2.4s          \n"
125                         "fmul   v6.4s, v6.4s, v2.4s          \n"
126                         // top_f32 = top_f32 + bias_tm
127                         "fadd   v5.4s, v5.4s, v3.4s          \n"
128                         "fadd   v6.4s, v6.4s, v3.4s          \n"
129                         // save top_f32
130                         "st1    {v5.4s, v6.4s}, [%2], #32    \n"
131                         "subs   %w0, %w0, #1                 \n"
132                         "bne    0b                           \n"
133                         : "=r"(nn),     // %0
134                         "=r"(intptr), // %1
135                         "=r"(ptr)     // %2
136                         : "0"(nn),
137                         "1"(intptr),
138                         "2"(ptr),
139                         "r"(scale), // %6
140                         "r"(bias)   // %7
141                         : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
142                 }
143 #else
144                 if (nn > 0)
145                 {
146                     asm volatile(
147                         "pld        [%1, #256]          \n"
148                         "vld1.s32   {d0-d3}, [%1]!      \n" //q0-q1 data
149                         "vdup.f32   q10, %6             \n" //q10 scale
150                         "vdup.f32   q12, %7             \n" //q12 bias
151 
152                         "0:                             \n"
153                         "vcvt.f32.s32 q0, q0            \n"
154                         "vcvt.f32.s32 q1, q1            \n"
155 
156                         "vmul.f32   q0,q0,q10           \n"
157                         "vmul.f32   q1,q1,q10           \n"
158 
159                         "vadd.f32   q2,q0,q12           \n"
160                         "vadd.f32   q3,q1,q12           \n"
161 
162                         "pld        [%1, #256]          \n"
163                         "vld1.s32   {d0-d3}, [%1]!      \n"
164                         "vst1.f32   {d4-d7}, [%2]!      \n"
165 
166                         "subs       %0, #1              \n"
167                         "bne        0b                  \n"
168 
169                         "sub        %1, #32             \n"
170                         : "=r"(nn),     // %0
171                         "=r"(intptr), // %1
172                         "=r"(ptr)     // %2
173                         : "0"(nn),
174                         "1"(intptr),
175                         "2"(ptr),
176                         "r"(scale), // %6
177                         "r"(bias)   // %7
178                         : "cc", "memory", "q0", "q1", "q2", "q3", "q10", "q12");
179                 }
180 #endif // __aarch64__
181 #endif // __ARM_NEON
182                 for (; remain > 0; remain--)
183                 {
184                     *ptr = *intptr * scale + bias;
185 
186                     intptr++;
187                     ptr++;
188                 }
189             }
190         }
191         else
192         {
193             #pragma omp parallel for num_threads(opt.num_threads)
194             for (int q = 0; q < channels; q++)
195             {
196                 int* intptr = bottom_top_blob.channel(q);
197                 float* ptr = bottom_top_blob.channel(q);
198 
199 #if __ARM_NEON
200                 int nn = size >> 3;
201                 int remain = size & 7;
202 #else
203                 int remain = size;
204 #endif // __ARM_NEON
205 
206 #if __ARM_NEON
207 #if __aarch64__
208                 if (nn > 0)
209                 {
210                     asm volatile(
211                         "dup    v2.4s, %w6                   \n" // scale
212                         "0:                                  \n"
213                         "prfm   pldl1keep, [%1, #128]      \n"
214                         "ld1    {v0.4s, v1.4s}, [%1], #32    \n" // data
215                         // top_s32 -> top_f32
216                         "scvtf  v5.4s, v0.4s                 \n"
217                         "scvtf  v6.4s, v1.4s                 \n"
218                         // top_f32 = top_f32 * scale_out
219                         "fmul   v5.4s, v5.4s, v2.4s          \n"
220                         "fmul   v6.4s, v6.4s, v2.4s          \n"
221                         // save top_f32
222                         "st1    {v5.4s, v6.4s}, [%2], #32    \n"
223                         "subs   %w0, %w0, #1                 \n"
224                         "bne    0b                           \n"
225                         : "=r"(nn),     // %0
226                         "=r"(intptr), // %1
227                         "=r"(ptr)     // %2
228                         : "0"(nn),
229                         "1"(intptr),
230                         "2"(ptr),
231                         "r"(scale) // %6
232                         : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6");
233                 }
234 #else
235                 if (nn > 0)
236                 {
237                     asm volatile(
238                         "pld        [%1, #256]          \n"
239                         "vld1.s32   {d0-d3}, [%1]!      \n" //q0-q1 data
240                         "vdup.f32   q10, %6             \n" //q10 scale
241 
242                         "0:                             \n"
243                         "vcvt.f32.s32 q0, q0            \n"
244                         "vcvt.f32.s32 q1, q1            \n"
245 
246                         "vmul.f32   q2,q0,q10           \n"
247                         "vmul.f32   q3,q1,q10           \n"
248 
249                         "pld        [%1, #256]          \n"
250                         "vld1.s32   {d0-d3}, [%1]!      \n"
251                         "vst1.f32   {d4-d7}, [%2]!      \n"
252 
253                         "subs       %0, #1              \n"
254                         "bne        0b                  \n"
255 
256                         "sub        %1, #32             \n"
257                         : "=r"(nn),     // %0
258                         "=r"(intptr), // %1
259                         "=r"(ptr)     // %2
260                         : "0"(nn),
261                         "1"(intptr),
262                         "2"(ptr),
263                         "r"(scale) // %6
264                         : "cc", "memory", "q0", "q1", "q2", "q3", "q10", "q12");
265                 }
266 #endif // __aarch64__
267 #endif // __ARM_NEON
268                 for (; remain > 0; remain--)
269                 {
270                     *ptr = *intptr * scale;
271 
272                     intptr++;
273                     ptr++;
274                 }
275             }
276         }
277     }
278 
279     return 0;
280 }
281 
282 } // namespace ncnn
283