1// Tencent is pleased to support the open source community by making ncnn available.
2//
3// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4//
5// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6// in compliance with the License. You may obtain a copy of the License at
7//
8// https://opensource.org/licenses/BSD-3-Clause
9//
10// Unless required by applicable law or agreed to in writing, software distributed
11// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13// specific language governing permissions and limitations under the License.
14
15#version 450
16
17#if NCNN_fp16_storage
18#extension GL_EXT_shader_16bit_storage: require
19#endif
20#if NCNN_fp16_arithmetic
21#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
22#endif
23
// Reduction mode, chosen at pipeline creation via specialization constants.
// main() implements the (1,1), (1,0) and (0,1) combinations.
layout (constant_id = 0) const int across_spatial = 0;
layout (constant_id = 1) const int across_channel = 0;

#if NCNN_image_shader
// Image path: input is read with texelFetch, partial sums are written with imageStore.
layout (binding = 0) uniform highp sampler3D square_blob;
layout (binding = 1, rgba32f) writeonly uniform highp image3D sqsum_blob;
#else
// Buffer path: flat vec4 arrays addressed as channel * cstep + x.
layout (binding = 0) readonly buffer square_blob { vec4 square_blob_data[]; };
layout (binding = 1) writeonly buffer sqsum_blob { vec4 sqsum_blob_data[]; };
#endif

// Push constants: input extent + channel stride, then output extent + channel stride.
// Members are consecutive 32-bit ints, so grouping them on one line keeps the layout.
layout (push_constant) uniform parameter
{
    int w, h, c, cstep;

    int outw, outh, outc, outcstep;
} p;
47
// One reduction step: each invocation sums a small window of square_blob
// (presumably already-squared values -- confirm against the host side) and
// writes a single vec4 to sqsum_blob at its own (gx, gy, gz).
//
// Window per mode:
//   across_spatial && across_channel : image 2x2x2 over (x, y, z); buffer 2x2 over (x, channel)
//   across_spatial only              : image 2x2 over (x, y);      buffer 4 along x
//   across_channel only              : 4 along the channel axis (both paths)
// Windows that touch the last element of an axis read only the elements that remain.
// Partial sums are accumulated in a fixed element order (v0, v1, ... as in the
// corner numbering) so the float result does not depend on which elements exist.
// NOTE(review): sqsum is left unset when both specialization constants are 0;
// presumably the host never dispatches that configuration -- confirm.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Only invocations inside the output extent do any work.
    if (!(gx < p.outw && gy < p.outh && gz < p.outc))
        return;

    vec4 sqsum;

    if (across_spatial == 1 && across_channel == 1)
    {
#if NCNN_image_shader
        int x0 = gx * 2;
        int y0 = gy * 2;
        int z0 = gz * 2;

        // The +1 neighbour on an axis is read unless x0/y0/z0 is that axis' last element.
        bool has_x1 = x0 != p.w - 1;
        bool has_y1 = y0 != p.h - 1;
        bool has_z1 = z0 != p.c - 1;

        sqsum = texelFetch(square_blob, ivec3(x0, y0, z0), 0);
        if (has_x1)
            sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0, z0), 0);
        if (has_y1)
            sqsum += texelFetch(square_blob, ivec3(x0, y0 + 1, z0), 0);
        if (has_x1 && has_y1)
            sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0 + 1, z0), 0);

        if (has_z1)
        {
            sqsum += texelFetch(square_blob, ivec3(x0, y0, z0 + 1), 0);
            if (has_x1)
                sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0, z0 + 1), 0);
            if (has_y1)
                sqsum += texelFetch(square_blob, ivec3(x0, y0 + 1, z0 + 1), 0);
            if (has_x1 && has_y1)
                sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0 + 1, z0 + 1), 0);
        }
#else
        // Buffer layout has no separate y here: the window is 2 along x times 2 channels.
        int x0 = gx * 2;
        int z0 = gz * 2;
        int i0 = z0 * p.cstep + x0;

        bool has_x1 = x0 != p.w - 1;
        bool has_z1 = z0 != p.c - 1;

        sqsum = square_blob_data[i0];
        if (has_x1)
            sqsum += square_blob_data[i0 + 1];

        if (has_z1)
        {
            sqsum += square_blob_data[i0 + p.cstep];
            if (has_x1)
                sqsum += square_blob_data[i0 + p.cstep + 1];
        }
#endif
    }

    if (across_spatial == 1 && across_channel == 0)
    {
#if NCNN_image_shader
        int x0 = gx * 2;
        int y0 = gy * 2;

        bool has_x1 = x0 != p.w - 1;
        bool has_y1 = y0 != p.h - 1;

        sqsum = texelFetch(square_blob, ivec3(x0, y0, gz), 0);
        if (has_x1)
            sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0, gz), 0);
        if (has_y1)
            sqsum += texelFetch(square_blob, ivec3(x0, y0 + 1, gz), 0);
        if (has_x1 && has_y1)
            sqsum += texelFetch(square_blob, ivec3(x0 + 1, y0 + 1, gz), 0);
#else
        int x0 = gx * 4;
        int i0 = gz * p.cstep + x0;

        // A 1-3 element tail is summed when the window reaches the end of the row,
        // otherwise all 4 elements.
        int remain = p.w - x0;
        int count = (remain >= 1 && remain <= 3) ? remain : 4;

        sqsum = square_blob_data[i0];
        for (int k = 1; k < count; k++)
            sqsum += square_blob_data[i0 + k];
#endif
    }

    if (across_spatial == 0 && across_channel == 1)
    {
        int z0 = gz * 4;

        // Same tail rule as above, applied along the channel axis.
        int remain = p.c - z0;
        int count = (remain >= 1 && remain <= 3) ? remain : 4;

#if NCNN_image_shader
        sqsum = texelFetch(square_blob, ivec3(gx, gy, z0), 0);
        for (int k = 1; k < count; k++)
            sqsum += texelFetch(square_blob, ivec3(gx, gy, z0 + k), 0);
#else
        sqsum = square_blob_data[z0 * p.cstep + gx];
        for (int k = 1; k < count; k++)
            sqsum += square_blob_data[(z0 + k) * p.cstep + gx];
#endif
    }

#if NCNN_image_shader
    imageStore(sqsum_blob, ivec3(gx, gy, gz), sqsum);
#else
    sqsum_blob_data[gz * p.outcstep + gx] = sqsum;
#endif
}
358