// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

// Interface declarations for the reshape shader: specialization constants,
// resource bindings and the push-constant block consumed by main().
//
// NOTE(review): sfp, sfpvec2, unfp, imfmtc4 and the NCNN_* switches are not
// defined in this file; they are presumably injected by the build preamble.

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
// Storage type for one 8-element output texel, held as two 4-wide fp16 halves.
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Selects which index formula main() uses for the output (1, 2 or 3).
layout (constant_id = 0) const int ndim = 0;

// Input shape, as specialization constants. All default to 0.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int dims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int w = 0;
layout (constant_id = shape_constant_id_offset + 2) const int h = 0;
layout (constant_id = shape_constant_id_offset + 3) const int c = 0;
layout (constant_id = shape_constant_id_offset + 4) const int cstep = 0;

// Output shape, as specialization constants. All default to 0.
layout (constant_id = shape_constant_id_offset + 5) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int outcstep = 0;

// binding 0 is the source blob (read only), binding 1 the destination blob
// (write only). The image variant uses 3D image resources; the buffer variant
// uses storage buffers.
#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D bottom_blob_3d;
layout (binding = 1, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
#if NCNN_fp16_packed
// With fp16 packing the source is addressed in 2-element units.
layout (binding = 0) readonly buffer bottom_blob { sfpvec2 bottom_blob_data[]; };
#else
// Otherwise the source is addressed one scalar at a time.
layout (binding = 0) readonly buffer bottom_blob { sfp bottom_blob_data[]; };
#endif
layout (binding = 1) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

// Runtime copy of the input/output shapes. Field names mirror the
// specialization constants above.
layout (push_constant) uniform parameter
{
    int dims;
    int w;
    int h;
    int c;
    int cstep;

    int outdims;
    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

// Reshape entry point: each invocation produces one 8-element output texel
// at (gx, gy, gz) by gathering eight scalars from a source blob that stores
// 4 elements per texel.
//
// Steps:
//   1. Compute the flat, unpacked indices of the eight scalars that make up
//      this output texel: i4 holds the first four, ii4 the last four.
//   2. Split every flat index into source coordinates using the source
//      width/height, then into a (texel, lane) pair for the 4-wide packing.
//   3. Gather the eight scalars and store them as a single 8-wide texel.
//
// NOTE(review): psc(), afpvec4/afpvec8, image3d_ld4/image3d_st8,
// buffer_ld2/buffer_st8 and buffer_cp1to8 are not defined in this file;
// they are presumably supplied by the shader build preamble.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Invocations outside the output extent have nothing to write.
    if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
        return;

    // Only 1-, 2- and 3-dimensional outputs are handled below. Bail out for
    // anything else instead of reading uninitialized indices and writing to
    // an undefined location.
    if (ndim < 1 || ndim > 3)
        return;

    ivec4 i4;  // flat indices of output lanes 0..3
    ivec4 ii4; // flat indices of output lanes 4..7

    if (ndim == 1)
    {
        // The packed axis is x: eight consecutive scalars.
        i4 = gx * 8 + ivec4(0, 1, 2, 3);
        ii4 = i4 + 4;
    }
    else if (ndim == 2)
    {
        // The packed axis is y: the eight lanes are eight consecutive rows.
        i4 = (gy * 8) * psc(outw) + gx + ivec4(0, 1, 2, 3) * psc(outw);
        ii4 = i4 + 4 * psc(outw);
    }
    else // ndim == 3
    {
        // The packed axis is z: the eight lanes are eight consecutive planes.
        i4 = (gz * 8) * psc(outh) * psc(outw) + gy * psc(outw) + gx + ivec4(0, 1, 2, 3) * psc(outh) * psc(outw);
        ii4 = i4 + 4 * psc(outh) * psc(outw);
    }

#if NCNN_image_shader
    afpvec8 v;

    if (psc(dims) == 1)
    {
        // 1D source: texel = index / 4, lane = index % 4.
        ivec4 x4 = i4;
        ivec4 xx4 = ii4;

        afpvec4 v0 = image3d_ld4(bottom_blob_3d, ivec3(x4.r / 4, 0, 0));
        afpvec4 v1 = image3d_ld4(bottom_blob_3d, ivec3(x4.g / 4, 0, 0));
        afpvec4 v2 = image3d_ld4(bottom_blob_3d, ivec3(x4.b / 4, 0, 0));
        afpvec4 v3 = image3d_ld4(bottom_blob_3d, ivec3(x4.a / 4, 0, 0));
        afpvec4 v4 = image3d_ld4(bottom_blob_3d, ivec3(xx4.r / 4, 0, 0));
        afpvec4 v5 = image3d_ld4(bottom_blob_3d, ivec3(xx4.g / 4, 0, 0));
        afpvec4 v6 = image3d_ld4(bottom_blob_3d, ivec3(xx4.b / 4, 0, 0));
        afpvec4 v7 = image3d_ld4(bottom_blob_3d, ivec3(xx4.a / 4, 0, 0));

        v[0].r = v0[x4.r % 4];
        v[0].g = v1[x4.g % 4];
        v[0].b = v2[x4.b % 4];
        v[0].a = v3[x4.a % 4];
        v[1].r = v4[xx4.r % 4];
        v[1].g = v5[xx4.g % 4];
        v[1].b = v6[xx4.b % 4];
        v[1].a = v7[xx4.a % 4];
    }
    else if (psc(dims) == 2)
    {
        // 2D source: rows are packed by 4, so texel row = y / 4, lane = y % 4.
        ivec4 y4 = i4 / psc(w);
        ivec4 x4 = i4 % psc(w);
        ivec4 yy4 = ii4 / psc(w);
        ivec4 xx4 = ii4 % psc(w);

        afpvec4 v0 = image3d_ld4(bottom_blob_3d, ivec3(x4.r, y4.r / 4, 0));
        afpvec4 v1 = image3d_ld4(bottom_blob_3d, ivec3(x4.g, y4.g / 4, 0));
        afpvec4 v2 = image3d_ld4(bottom_blob_3d, ivec3(x4.b, y4.b / 4, 0));
        afpvec4 v3 = image3d_ld4(bottom_blob_3d, ivec3(x4.a, y4.a / 4, 0));
        afpvec4 v4 = image3d_ld4(bottom_blob_3d, ivec3(xx4.r, yy4.r / 4, 0));
        afpvec4 v5 = image3d_ld4(bottom_blob_3d, ivec3(xx4.g, yy4.g / 4, 0));
        afpvec4 v6 = image3d_ld4(bottom_blob_3d, ivec3(xx4.b, yy4.b / 4, 0));
        afpvec4 v7 = image3d_ld4(bottom_blob_3d, ivec3(xx4.a, yy4.a / 4, 0));

        v[0].r = v0[y4.r % 4];
        v[0].g = v1[y4.g % 4];
        v[0].b = v2[y4.b % 4];
        v[0].a = v3[y4.a % 4];
        v[1].r = v4[yy4.r % 4];
        v[1].g = v5[yy4.g % 4];
        v[1].b = v6[yy4.b % 4];
        v[1].a = v7[yy4.a % 4];
    }
    else // if (psc(dims) == 3)
    {
        // 3D source: planes are packed by 4, so texel plane = z / 4, lane = z % 4.
        int size = psc(w) * psc(h);

        ivec4 z4 = i4 / size;
        ivec4 y4 = i4 % size / psc(w);
        ivec4 x4 = i4 % size % psc(w);
        ivec4 zz4 = ii4 / size;
        ivec4 yy4 = ii4 % size / psc(w);
        ivec4 xx4 = ii4 % size % psc(w);

        afpvec4 v0 = image3d_ld4(bottom_blob_3d, ivec3(x4.r, y4.r, z4.r / 4));
        afpvec4 v1 = image3d_ld4(bottom_blob_3d, ivec3(x4.g, y4.g, z4.g / 4));
        afpvec4 v2 = image3d_ld4(bottom_blob_3d, ivec3(x4.b, y4.b, z4.b / 4));
        afpvec4 v3 = image3d_ld4(bottom_blob_3d, ivec3(x4.a, y4.a, z4.a / 4));
        afpvec4 v4 = image3d_ld4(bottom_blob_3d, ivec3(xx4.r, yy4.r, zz4.r / 4));
        afpvec4 v5 = image3d_ld4(bottom_blob_3d, ivec3(xx4.g, yy4.g, zz4.g / 4));
        afpvec4 v6 = image3d_ld4(bottom_blob_3d, ivec3(xx4.b, yy4.b, zz4.b / 4));
        afpvec4 v7 = image3d_ld4(bottom_blob_3d, ivec3(xx4.a, yy4.a, zz4.a / 4));

        v[0].r = v0[z4.r % 4];
        v[0].g = v1[z4.g % 4];
        v[0].b = v2[z4.b % 4];
        v[0].a = v3[z4.a % 4];
        v[1].r = v4[zz4.r % 4];
        v[1].g = v5[zz4.g % 4];
        v[1].b = v6[zz4.b % 4];
        v[1].a = v7[zz4.a % 4];
    }

    // Coordinates beyond the output's dimensionality are pinned to 0.
    if (ndim == 1)
    {
        image3d_st8(top_blob_3d, ivec3(gx, 0, 0), v);
    }
    else if (ndim == 2)
    {
        image3d_st8(top_blob_3d, ivec3(gx, gy, 0), v);
    }
    else // ndim == 3
    {
        image3d_st8(top_blob_3d, ivec3(gx, gy, gz), v);
    }
#else
    // Index of this invocation's 8-wide element in top_blob_data. Only the
    // 3D layout uses the channel stride (outcstep).
    int gi;

    if (ndim == 1)
        gi = gx;
    else if (ndim == 2)
        gi = gy * psc(outw) + gx;
    else // ndim == 3
        gi = gz * psc(outcstep) + gy * psc(outw) + gx;

#if NCNN_fp16_packed
    // The source is addressed in 2-element units, so each 4-wide source
    // texel spans two array elements. For every scalar we need the array
    // element to load and which of its two lanes to pick.
    ivec4 v_offset;  // array elements for output lanes 0..3
    ivec4 vv_offset; // array elements for output lanes 4..7
    ivec4 v_lane;    // lane (0 or 1) inside each v_offset element
    ivec4 vv_lane;   // lane (0 or 1) inside each vv_offset element

    if (psc(dims) == 1)
    {
        v_offset = i4 / 2;
        v_lane = i4 % 2;
        vv_offset = ii4 / 2;
        vv_lane = ii4 % 2;
    }
    else if (psc(dims) == 2)
    {
        ivec4 y4 = i4 / psc(w);
        ivec4 x4 = i4 % psc(w);
        ivec4 yy4 = ii4 / psc(w);
        ivec4 xx4 = ii4 % psc(w);

        // (texel index) * 2 + which half of the 4-wide texel
        v_offset = ((y4 / 4) * psc(w) + x4) * 2 + (y4 % 4) / 2;
        v_lane = y4 % 2;
        vv_offset = ((yy4 / 4) * psc(w) + xx4) * 2 + (yy4 % 4) / 2;
        vv_lane = yy4 % 2;
    }
    else // if (psc(dims) == 3)
    {
        int size = psc(w) * psc(h);

        ivec4 z4 = i4 / size;
        ivec4 y4 = i4 % size / psc(w);
        ivec4 x4 = i4 % size % psc(w);
        ivec4 zz4 = ii4 / size;
        ivec4 yy4 = ii4 % size / psc(w);
        ivec4 xx4 = ii4 % size % psc(w);

        v_offset = ((z4 / 4) * psc(cstep) + y4 * psc(w) + x4) * 2 + (z4 % 4) / 2;
        v_lane = z4 % 2;
        vv_offset = ((zz4 / 4) * psc(cstep) + yy4 * psc(w) + xx4) * 2 + (zz4 % 4) / 2;
        vv_lane = zz4 % 2;
    }

    afpvec2 vr = buffer_ld2(bottom_blob_data, v_offset.r);
    afpvec2 vg = buffer_ld2(bottom_blob_data, v_offset.g);
    afpvec2 vb = buffer_ld2(bottom_blob_data, v_offset.b);
    afpvec2 va = buffer_ld2(bottom_blob_data, v_offset.a);

    afpvec2 vvr = buffer_ld2(bottom_blob_data, vv_offset.r);
    afpvec2 vvg = buffer_ld2(bottom_blob_data, vv_offset.g);
    afpvec2 vvb = buffer_ld2(bottom_blob_data, vv_offset.b);
    afpvec2 vva = buffer_ld2(bottom_blob_data, vv_offset.a);

    afpvec8 v = afpvec8(vr[v_lane.r], vg[v_lane.g], vb[v_lane.b], va[v_lane.a], vvr[vv_lane.r], vvg[vv_lane.g], vvb[vv_lane.b], vva[vv_lane.a]);

    buffer_st8(top_blob_data, gi, v);
#else
    // The source is addressed one scalar at a time:
    // offset = (texel index) * 4 + lane.
    ivec4 v_offset;  // scalar offsets for output lanes 0..3
    ivec4 vv_offset; // scalar offsets for output lanes 4..7

    if (psc(dims) == 1)
    {
        v_offset = i4;
        vv_offset = ii4;
    }
    else if (psc(dims) == 2)
    {
        ivec4 y4 = i4 / psc(w);
        ivec4 x4 = i4 % psc(w);
        ivec4 yy4 = ii4 / psc(w);
        ivec4 xx4 = ii4 % psc(w);

        v_offset = ((y4 / 4) * psc(w) + x4) * 4 + y4 % 4;
        vv_offset = ((yy4 / 4) * psc(w) + xx4) * 4 + yy4 % 4;
    }
    else // if (psc(dims) == 3)
    {
        int size = psc(w) * psc(h);

        ivec4 z4 = i4 / size;
        ivec4 y4 = i4 % size / psc(w);
        ivec4 x4 = i4 % size % psc(w);
        ivec4 zz4 = ii4 / size;
        ivec4 yy4 = ii4 % size / psc(w);
        ivec4 xx4 = ii4 % size % psc(w);

        v_offset = ((z4 / 4) * psc(cstep) + y4 * psc(w) + x4) * 4 + z4 % 4;
        vv_offset = ((zz4 / 4) * psc(cstep) + yy4 * psc(w) + xx4) * 4 + zz4 % 4;
    }

    buffer_cp1to8(top_blob_data, gi, bottom_blob_data, v_offset, vv_offset);
#endif
#endif
}
