// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// With 16-bit storage enabled, one input element is declared here as two
// f16vec4 halves, i.e. 8 half-precision values.
// NOTE(review): when NCNN_fp16_storage is off, sfpvec8 is not declared in this
// file; it is presumably supplied by the shader preamble - confirm.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Specialization constants selecting the reduction axes. main() has branches
// for (across_spatial, across_channel) = (1, 1), (1, 0) and (0, 1).
layout (constant_id = 0) const int across_spatial = 0;
layout (constant_id = 1) const int across_channel = 0;

// Binding 0 is the input whose elements main() squares and sums.
// Binding 1 receives the partial sums as 32-bit floats: two rgba32f texels per
// element on the image path, one mat2x4 per element on the buffer path.
// NOTE(review): `unfp` is not defined in this file; presumably a precision
// qualifier provided by the shader preamble - confirm.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D square_blob;
layout (binding = 1, rgba32f) writeonly uniform highp image3D sqsum_blob;
#else
layout (binding = 0) readonly buffer bottom_top_blob { sfpvec8 bottom_top_blob_data[]; };
layout (binding = 1) writeonly buffer sqsum_blob { mat2x4 sqsum_blob_data[]; };
#endif

// Push constants.
//   w, h, c          input extents; main() compares against them to detect the
//                    last (partial) reduction window along each reduced axis
//   cstep            element stride between consecutive input channels (buffer path)
//   outw, outh, outc output extents; invocations outside them exit immediately
//   outcstep         element stride between consecutive output channels (buffer path)
layout (push_constant) uniform parameter
{
    int w;
    int h;
    int c;
    int cstep;

    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

#if NCNN_image_shader
// Fetches the 8-wide element at (x, y, z), stored as the two texels
// (2x, y, z) and (2x + 1, y, z), and returns its component-wise square.
mat2x4 load_squared(int x, int y, int z)
{
    mat2x4 v = mat2x4(texelFetch(square_blob, ivec3(x * 2, y, z), 0), texelFetch(square_blob, ivec3(x * 2 + 1, y, z), 0));

    return mat2x4(v[0] * v[0], v[1] * v[1]);
}
#else
// Loads the 8-wide element at buffer index `offset` and returns its
// component-wise square.
mat2x4 load_squared(int offset)
{
    mat2x4 v = mat2x4(buffer_ld8(bottom_top_blob_data, offset));

    return mat2x4(v[0] * v[0], v[1] * v[1]);
}
#endif

// One invocation produces one output element: the sum of squares of a small
// window of input elements.
//
// Window per output element (x, y, z extents):
//   across_spatial && across_channel : image 2 x 2 x 2, buffer 2 x 2 (x, z)
//   across_spatial only              : image 2 x 2 (x, y), buffer 4 (x)
//   across_channel only              : 4 (z)
//
// The window is clamped against p.w / p.h / p.c, so the last window along a
// reduced axis only covers the elements that exist. Elements are accumulated
// with x innermost, then y, then z. If neither specialization constant is 1
// the output element is written as zero.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
        return;

    // Start from zero so every path below, including "no reduction selected",
    // stores a defined value.
    mat2x4 sqsum = mat2x4(vec4(0.f), vec4(0.f));

    if (across_spatial == 1 && across_channel == 1)
    {
        int sz = gz * 2;
        int sx = gx * 2;

        // number of valid elements left along each reduced axis (at most 2)
        int nz = min(2, p.c - sz);
        int nx = min(2, p.w - sx);

#if NCNN_image_shader
        int sy = gy * 2;
        int ny = min(2, p.h - sy);

        for (int z = 0; z < nz; z++)
        {
            for (int y = 0; y < ny; y++)
            {
                for (int x = 0; x < nx; x++)
                {
                    sqsum += load_squared(sx + x, sy + y, sz + z);
                }
            }
        }
#else
        for (int z = 0; z < nz; z++)
        {
            for (int x = 0; x < nx; x++)
            {
                sqsum += load_squared((sz + z) * p.cstep + sx + x);
            }
        }
#endif
    }

    if (across_spatial == 1 && across_channel == 0)
    {
#if NCNN_image_shader
        int sy = gy * 2;
        int sx = gx * 2;

        int ny = min(2, p.h - sy);
        int nx = min(2, p.w - sx);

        for (int y = 0; y < ny; y++)
        {
            for (int x = 0; x < nx; x++)
            {
                sqsum += load_squared(sx + x, sy + y, gz);
            }
        }
#else
        int sx = gx * 4;

        int nx = min(4, p.w - sx);

        for (int x = 0; x < nx; x++)
        {
            sqsum += load_squared(gz * p.cstep + sx + x);
        }
#endif
    }

    if (across_spatial == 0 && across_channel == 1)
    {
        int sz = gz * 4;

        int nz = min(4, p.c - sz);

        for (int z = 0; z < nz; z++)
        {
#if NCNN_image_shader
            sqsum += load_squared(gx, gy, sz + z);
#else
            sqsum += load_squared((sz + z) * p.cstep + gx);
#endif
        }
    }

#if NCNN_image_shader
    // one 8-wide result occupies two adjacent texels along x
    imageStore(sqsum_blob, ivec3(gx * 2, gy, gz), sqsum[0]);
    imageStore(sqsum_blob, ivec3(gx * 2 + 1, gy, gz), sqsum[1]);
#else
    int gi = gz * p.outcstep + gx;

    sqsum_blob_data[gi] = sqsum;
#endif
}