// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// Squares small neighbourhoods of pack8 elements and sums them in fp32, writing one
// mat2x4 partial sum per invocation. The specialization constants select the axes:
//   across_spatial=1, across_channel=1 : image 2x2x2 (x, y, channel); buffer 2 (x) x 2 (channel)
//   across_spatial=1, across_channel=0 : image 2x2 (x, y);            buffer 4 along x
//   across_spatial=0, across_channel=1 : 4 consecutive channels (image and buffer)

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

layout (constant_id = 0) const int across_spatial = 0;
layout (constant_id = 1) const int across_channel = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D square_blob;
layout (binding = 1, rgba32f) writeonly uniform highp image3D sqsum_blob;
#else
layout (binding = 0) readonly buffer bottom_top_blob { sfpvec8 bottom_top_blob_data[]; };
layout (binding = 1) writeonly buffer sqsum_blob { mat2x4 sqsum_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    // input extent; cstep is the per-channel stride of the buffer layout
    int w;
    int h;
    int c;
    int cstep;

    // output extent; outcstep is the per-channel stride of the output buffer
    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

// Component-wise square of one pack8 value (matrixCompMult, NOT a matrix product),
// kept in fp32 so the subsequent sums do not accumulate in reduced precision.
mat2x4 square8(mat2x4 v)
{
    return matrixCompMult(v, v);
}

#if NCNN_image_shader
// Fetch pack8 element (x, y, z): its 8 lanes occupy two adjacent rgba texels along x.
mat2x4 load8(int x, int y, int z)
{
    return mat2x4(texelFetch(square_blob, ivec3(x * 2, y, z), 0), texelFetch(square_blob, ivec3(x * 2 + 1, y, z), 0));
}
#else
// Load the pack8 element at linear index i and widen it to fp32.
mat2x4 load8(int i)
{
    return mat2x4(buffer_ld8(bottom_top_blob_data, i));
}
#endif

void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    if (gx >= p.outw || gy >= p.outh || gz >= p.outc)
        return;

    // Zero-initialized so that a specialization-constant combination not handled below
    // stores zeros rather than an undefined value.
    mat2x4 sqsum = mat2x4(0.0);

    // Each branch starts from the first element and conditionally adds neighbours in a
    // fixed order, so the floating-point summation order equals the plain unrolled sum
    // for whichever subset of neighbours lies inside the input.

    if (across_spatial == 1 && across_channel == 1)
    {
#if NCNN_image_shader
        int sx = gx * 2;
        int sy = gy * 2;
        int sz = gz * 2;

        // a neighbour exists unless this invocation sits on the last column/row/channel
        bool has_x1 = sx + 1 < p.w;
        bool has_y1 = sy + 1 < p.h;
        bool has_z1 = sz + 1 < p.c;

        sqsum = square8(load8(sx, sy, sz));
        if (has_x1)
            sqsum += square8(load8(sx + 1, sy, sz));
        if (has_y1)
            sqsum += square8(load8(sx, sy + 1, sz));
        if (has_x1 && has_y1)
            sqsum += square8(load8(sx + 1, sy + 1, sz));
        if (has_z1)
            sqsum += square8(load8(sx, sy, sz + 1));
        if (has_z1 && has_x1)
            sqsum += square8(load8(sx + 1, sy, sz + 1));
        if (has_z1 && has_y1)
            sqsum += square8(load8(sx, sy + 1, sz + 1));
        if (has_z1 && has_x1 && has_y1)
            sqsum += square8(load8(sx + 1, sy + 1, sz + 1));
#else
        // NOTE(review): the buffer path walks a single linear x axis bounded by p.w;
        // presumably the host passes the flattened spatial size here -- confirm.
        int sx = gx * 2;
        int sz = gz * 2;

        int i0 = sz * p.cstep + sx;

        bool has_x1 = sx + 1 < p.w;
        bool has_z1 = sz + 1 < p.c;

        sqsum = square8(load8(i0));
        if (has_x1)
            sqsum += square8(load8(i0 + 1));
        if (has_z1)
            sqsum += square8(load8(i0 + p.cstep));
        if (has_x1 && has_z1)
            sqsum += square8(load8(i0 + p.cstep + 1));
#endif
    }

    if (across_spatial == 1 && across_channel == 0)
    {
#if NCNN_image_shader
        int sx = gx * 2;
        int sy = gy * 2;

        bool has_x1 = sx + 1 < p.w;
        bool has_y1 = sy + 1 < p.h;

        sqsum = square8(load8(sx, sy, gz));
        if (has_x1)
            sqsum += square8(load8(sx + 1, sy, gz));
        if (has_y1)
            sqsum += square8(load8(sx, sy + 1, gz));
        if (has_x1 && has_y1)
            sqsum += square8(load8(sx + 1, sy + 1, gz));
#else
        // up to 4 consecutive elements along the linear x axis of channel gz
        int sx = gx * 4;

        int i0 = gz * p.cstep + sx;

        sqsum = square8(load8(i0));
        if (sx + 1 < p.w)
            sqsum += square8(load8(i0 + 1));
        if (sx + 2 < p.w)
            sqsum += square8(load8(i0 + 2));
        if (sx + 3 < p.w)
            sqsum += square8(load8(i0 + 3));
#endif
    }

    if (across_spatial == 0 && across_channel == 1)
    {
        // up to 4 consecutive channels at the same spatial position
        int sz = gz * 4;

#if NCNN_image_shader
        sqsum = square8(load8(gx, gy, sz));
        if (sz + 1 < p.c)
            sqsum += square8(load8(gx, gy, sz + 1));
        if (sz + 2 < p.c)
            sqsum += square8(load8(gx, gy, sz + 2));
        if (sz + 3 < p.c)
            sqsum += square8(load8(gx, gy, sz + 3));
#else
        int i0 = sz * p.cstep + gx;

        sqsum = square8(load8(i0));
        if (sz + 1 < p.c)
            sqsum += square8(load8(i0 + p.cstep));
        if (sz + 2 < p.c)
            sqsum += square8(load8(i0 + 2 * p.cstep));
        if (sz + 3 < p.c)
            sqsum += square8(load8(i0 + 3 * p.cstep));
#endif
    }

#if NCNN_image_shader
    // the two fp32 halves of the pack8 sum go to two adjacent rgba32f texels
    imageStore(sqsum_blob, ivec3(gx * 2, gy, gz), sqsum[0]);
    imageStore(sqsum_blob, ivec3(gx * 2 + 1, gy, gz), sqsum[1]);
#else
    int gi = gz * p.outcstep + gx;

    sqsum_blob_data[gi] = sqsum;
#endif
}