1// Tencent is pleased to support the open source community by making ncnn available.
2//
3// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4//
5// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6// in compliance with the License. You may obtain a copy of the License at
7//
8// https://opensource.org/licenses/BSD-3-Clause
9//
10// Unless required by applicable law or agreed to in writing, software distributed
11// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12// CONDITIONS OF ANY KIND, either express or implied. See the License for the
13// specific language governing permissions and limitations under the License.
14
15#version 450
16
17#if NCNN_fp16_storage
18#extension GL_EXT_shader_16bit_storage: require
19#endif
20#if NCNN_fp16_arithmetic
21#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
22#endif
23
// Selects the element-wise operation applied in main():
// 0 a+b, 1 a-b, 2 a*b, 3 a/b, 4 max(a,b), 5 min(a,b), 6 pow(a,b),
// 7 b-a (reverse subtract), 8 b/a (reverse divide).
layout (constant_id = 0) const int op_type = 0;

// Shapes of input a, input b and the output, five values per tensor:
// rank (dims), width, height, channel count, and the stride between
// z-slices (cstep) used when indexing the buffer variant.
// NOTE(review): these are read through psc(), which is defined outside this
// file; presumably a value of 0 defers to the same-named field of the
// push-constant block `p` - confirm against the shader preprocessor.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;

// Storage bindings: 0 = input a, 1 = input b, 2 = output.
#if NCNN_image_shader
// Image variant: inputs are sampled 3D images, the output is a write-only 3D
// image. unfp / imfmtc4 are macros supplied outside this file.
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
// Buffer variant: each blob is a flat array indexed linearly by main().
// sfpvec4 is presumably a 4-lane storage vector type defined outside this
// file - confirm.
layout (binding = 0) readonly buffer a_blob { sfpvec4 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec4 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec4 top_blob_data[]; };
#endif
54
// Runtime copies of the shape values that are also declared as specialization
// constants (same names, same order): rank, width, height, channels and
// z-slice stride for a, b and the output.
// NOTE(review): presumably psc(x) reads p.x when the specialization constant
// x is 0 - psc() is defined outside this file.
// Member order fixes the push-constant memory layout, which must match what
// the host uploads; do not reorder or insert members.
layout (push_constant) uniform parameter
{
    int adims;
    int aw;
    int ah;
    int ac;
    int acstep;

    int bdims;
    int bw;
    int bh;
    int bc;
    int bcstep;

    int outdims;
    int outw;
    int outh;
    int outc;
    int outcstep;
} p;
75
// Broadcasting binary operation: one invocation writes one output element at
// (gx, gy, gz).
//
// For each operand the shader derives a per-axis source coordinate (x, y, z)
// from the broadcast pattern implied by the operand ranks and shapes. It then
// either loads the texel at that coordinate (image variant) or flattens it to
// z * cstep + y * w + x (buffer variant). Both storage variants share this
// single coordinate table so they cannot disagree. A previous separate buffer
// table left ai/bi unassigned for several patterns, e.g. special types 2/4
// and same-rank 2-D/2-D.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
        return;

    // Source coordinates default to the output coordinate (operand has the
    // output's shape); the cases below redirect broadcast axes. Checks are
    // applied in order and each one only overwrites the axes it names, so a
    // later match refines an earlier one.
    int ax = gx;
    int ay = gy;
    int az = gz;
    int bx = gx;
    int by = gy;
    int bz = gz;

    if (psc(adims) == 3)
    {
        if (psc(bdims) == 3)
        {
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                bx = 0;
                by = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2
                bz = 0;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ax = 0;
                ay = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4
                az = 0;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                bx = 0;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                by = 0;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ax = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ay = 0;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18: b row index follows out y, b column index follows out z
            bx = gy;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 17: one b element per output z-slice
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 3)
        {
            // type 14: a row index follows out y, a column index follows out z
            ax = gy;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 12: one b element per output row
                bx = gy;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4
            ax = 0;
            ay = 0;
            az = 0;
        }
        else
        {
            if (psc(bdims) == 3)
            {
                // type 9: one a element per output z-slice
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 2)
            {
                // type 8: one a element per output row
                ax = gy;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6
                    bx = 0;
                    by = 0;
                    bz = 0;
                }
            }
        }
    }

#if NCNN_image_shader
    afpvec4 v1 = image3d_ld4(a_blob_3d, ivec3(ax, ay, az));
    afpvec4 v2 = image3d_ld4(b_blob_3d, ivec3(bx, by, bz));
#else
    // Flatten (x, y, z): rows are w elements apart, z-slices cstep apart,
    // matching the output index gi below. 2-D operands are addressed as
    // (x, y, 0) with row stride w, 1-D operands as (x, 0, 0).
    // NOTE(review): the unredirected default coordinates assume the output
    // has c == 1 (and h == 1 for 1-D) whenever an operand of lower rank keeps
    // them - this is what the host shape logic is expected to guarantee.
    const int ai = az * psc(acstep) + ay * psc(aw) + ax;
    const int bi = bz * psc(bcstep) + by * psc(bw) + bx;

    afpvec4 v1 = buffer_ld4(a_blob_data, ai);
    afpvec4 v2 = buffer_ld4(b_blob_data, bi);
#endif

    // op_type is a specialization constant (constant_id = 0), so each
    // comparison below is constant for a given pipeline. res is left
    // unassigned for op_type values outside 0..8.
    afpvec4 res;

    if (op_type == 0) res = v1 + v2;
    if (op_type == 1) res = v1 - v2;
    if (op_type == 2) res = v1 * v2;
    if (op_type == 3) res = v1 / v2;
    if (op_type == 4) res = max(v1, v2);
    if (op_type == 5) res = min(v1, v2);
    if (op_type == 6) res = pow(v1, v2);
    if (op_type == 7) res = v2 - v1;
    if (op_type == 8) res = v2 / v1;

#if NCNN_image_shader
    image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;
    buffer_st4(top_blob_data, gi, res);
#endif
}
366