1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
resize_bilinear_image_pack4_bf16s(const Mat & src,Mat & dst,float * alpha,int * xofs,float * beta,int * yofs)15 static void resize_bilinear_image_pack4_bf16s(const Mat& src, Mat& dst, float* alpha, int* xofs, float* beta, int* yofs)
16 {
17     int w = dst.w;
18     int h = dst.h;
19 
20     // loop body
21     Mat rowsbuf0(w, (size_t)4 * 4u, 4);
22     Mat rowsbuf1(w, (size_t)4 * 4u, 4);
23     float* rows0 = rowsbuf0;
24     float* rows1 = rowsbuf1;
25 
26     int prev_sy1 = -2;
27 
28     for (int dy = 0; dy < h; dy++)
29     {
30         int sy = yofs[dy];
31 
32         if (sy == prev_sy1)
33         {
34             // reuse all rows
35         }
36         else if (sy == prev_sy1 + 1)
37         {
38             // hresize one row
39             float* rows0_old = rows0;
40             rows0 = rows1;
41             rows1 = rows0_old;
42             const unsigned short* S1 = src.row<const unsigned short>(sy + 1);
43 
44             const float* alphap = alpha;
45             float* rows1p = rows1;
46             int dx = 0;
47             for (; dx < w; dx++)
48             {
49                 int sx = xofs[dx] * 4;
50                 const unsigned short* S1p = S1 + sx;
51 
52                 float32x2_t _a01 = vld1_f32(alphap);
53 
54                 float32x4_t _S10 = vcvt_f32_bf16(vld1_u16(S1p));
55                 float32x4_t _S11 = vcvt_f32_bf16(vld1_u16(S1p + 4));
56                 float32x4_t _rows1 = vmulq_lane_f32(_S10, _a01, 0);
57                 _rows1 = vmlaq_lane_f32(_rows1, _S11, _a01, 1);
58                 vst1q_f32(rows1p + dx * 4, _rows1);
59 
60                 alphap += 2;
61             }
62         }
63         else
64         {
65             // hresize two rows
66             const unsigned short* S0 = src.row<const unsigned short>(sy);
67             const unsigned short* S1 = src.row<const unsigned short>(sy + 1);
68 
69             const float* alphap = alpha;
70             float* rows0p = rows0;
71             float* rows1p = rows1;
72             int dx = 0;
73             for (; dx < w; dx++)
74             {
75                 int sx = xofs[dx] * 4;
76                 const unsigned short* S0p = S0 + sx;
77                 const unsigned short* S1p = S1 + sx;
78 
79                 float32x2_t _a01 = vld1_f32(alphap);
80 
81                 float32x4_t _S00 = vcvt_f32_bf16(vld1_u16(S0p));
82                 float32x4_t _S01 = vcvt_f32_bf16(vld1_u16(S0p + 4));
83                 float32x4_t _S10 = vcvt_f32_bf16(vld1_u16(S1p));
84                 float32x4_t _S11 = vcvt_f32_bf16(vld1_u16(S1p + 4));
85                 float32x4_t _rows0 = vmulq_lane_f32(_S00, _a01, 0);
86                 float32x4_t _rows1 = vmulq_lane_f32(_S10, _a01, 0);
87                 _rows0 = vmlaq_lane_f32(_rows0, _S01, _a01, 1);
88                 _rows1 = vmlaq_lane_f32(_rows1, _S11, _a01, 1);
89                 vst1q_f32(rows0p + dx * 4, _rows0);
90                 vst1q_f32(rows1p + dx * 4, _rows1);
91 
92                 alphap += 2;
93             }
94         }
95 
96         prev_sy1 = sy;
97 
98         // vresize
99         float32x2_t _b01 = vld1_f32(beta);
100 
101         float* rows0p = rows0;
102         float* rows1p = rows1;
103         unsigned short* Dp = dst.row<unsigned short>(dy);
104 
105         for (int dx = 0; dx < w; dx++)
106         {
107             float32x4_t _rows0 = vld1q_f32(rows0p);
108             float32x4_t _rows1 = vld1q_f32(rows1p);
109             float32x4_t _D = vmulq_lane_f32(_rows0, _b01, 0);
110             _D = vmlaq_lane_f32(_D, _rows1, _b01, 1);
111             vst1_u16(Dp, vcvt_bf16_f32(_D));
112 
113             Dp += 4;
114             rows0p += 4;
115             rows1p += 4;
116         }
117 
118         beta += 2;
119     }
120 }
121