// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#version 450

#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
struct sfpvec8 { f16vec4 abcd; f16vec4 efgh; };
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif

// Operator selector, see the dispatch at the end of main():
//   0: a + b   1: a - b   2: a * b   3: a / b   4: max(a, b)
//   5: min(a, b)   6: pow(a, b)   7: b - a   8: b / a
layout (constant_id = 0) const int op_type = 0;

// Shapes of the two inputs (a, b) and of the output as specialization constants.
// The same fields exist in the push-constant block below; psc() is defined
// outside this file and presumably picks between the two - TODO confirm.
#define shape_constant_id_offset 1
layout (constant_id = shape_constant_id_offset + 0) const int adims = 0;
layout (constant_id = shape_constant_id_offset + 1) const int aw = 0;
layout (constant_id = shape_constant_id_offset + 2) const int ah = 0;
layout (constant_id = shape_constant_id_offset + 3) const int ac = 0;
layout (constant_id = shape_constant_id_offset + 4) const int acstep = 0;

layout (constant_id = shape_constant_id_offset + 5) const int bdims = 0;
layout (constant_id = shape_constant_id_offset + 6) const int bw = 0;
layout (constant_id = shape_constant_id_offset + 7) const int bh = 0;
layout (constant_id = shape_constant_id_offset + 8) const int bc = 0;
layout (constant_id = shape_constant_id_offset + 9) const int bcstep = 0;

layout (constant_id = shape_constant_id_offset + 10) const int outdims = 0;
layout (constant_id = shape_constant_id_offset + 11) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 12) const int outh = 0;
layout (constant_id = shape_constant_id_offset + 13) const int outc = 0;
layout (constant_id = shape_constant_id_offset + 14) const int outcstep = 0;

#if NCNN_image_shader
layout (binding = 0) uniform unfp sampler3D a_blob_3d;
layout (binding = 1) uniform unfp sampler3D b_blob_3d;
layout (binding = 2, imfmtc4) writeonly uniform unfp image3D top_blob_3d;
#else
layout (binding = 0) readonly buffer a_blob { sfpvec8 a_blob_data[]; };
layout (binding = 1) readonly buffer b_blob { sfpvec8 b_blob_data[]; };
layout (binding = 2) writeonly buffer top_blob { sfpvec8 top_blob_data[]; };
#endif

layout (push_constant) uniform parameter
{
    int adims;
    int aw;
    int ah;
    int ac;
    int acstep;

    int bdims;
    int bw;
    int bh;
    int bc;
    int bcstep;

    int outdims;
    int outw;
    int outh;
    int outc;
    int outcstep;
} p;

// Broadcasting element-wise binary operator.
//
// Each invocation produces the output element at (gx, gy, gz). It first works
// out which element of a and which element of b feed that output, according
// to the shapes of a and b (the "type N" / "special type N" labels below name
// the individual broadcast cases), then applies the operator chosen by op_type
// and stores the result.
//
// The image path (NCNN_image_shader) addresses the operands by 3D coordinate,
// the buffer path by linear index. Both paths start from "same position as the
// output" and every broadcast rule only redirects the operand it broadcasts,
// so the two paths resolve the same cases in the same way.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Invocations outside the output extent have nothing to do.
    if (gx >= psc(outw) || gy >= psc(outh) || gz >= psc(outc))
        return;

#if NCNN_image_shader
    // Default: both operands are read at the output coordinate.
    int ax = gx;
    int ay = gy;
    int az = gz;
    int bx = gx;
    int by = gy;
    int bz = gz;

    if (psc(adims) == 3)
    {
        if (psc(bdims) == 3)
        {
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1
                bx = 0;
                by = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2
                bz = 0;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3
                ax = 0;
                ay = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4
                az = 0;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5
                bx = 0;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6
                by = 0;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7
                ax = 0;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8
                ay = 0;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18
            bx = gy;
            by = gz;
            bz = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 17
                bx = gz;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 3)
        {
            // type 14
            ax = gy;
            ay = gz;
            az = 0;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11
                bx = 0;
                by = 0;
                bz = 0;
            }
            else
            {
                // type 12
                bx = gy;
                by = 0;
                bz = 0;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4
            ax = 0;
            ay = 0;
            az = 0;
        }
        else
        {
            if (psc(bdims) == 3)
            {
                // type 9
                ax = gz;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 2)
            {
                // type 8
                ax = gy;
                ay = 0;
                az = 0;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6
                    bx = 0;
                    by = 0;
                    bz = 0;
                }
            }
        }
    }

    afpvec8 v1 = image3d_ld8(a_blob_3d, ivec3(ax, ay, az));
    afpvec8 v2 = image3d_ld8(b_blob_3d, ivec3(bx, by, bz));
#else
    // Linear index of the output element.
    const int gi = gz * psc(outcstep) + gy * psc(outw) + gx;

    // Default: both operands are read at the output index. Initialising here
    // guarantees the loads below never see an undefined index, and lets every
    // rule below redirect only the operand it actually broadcasts - the same
    // scheme the image path uses with its per-axis coordinates.
    // A coordinate (x, y, z) of a 3-dim operand maps to z * cstep + y * w + x;
    // for 2-dim operands the z term is absent, for 1-dim operands only x is left.
    int ai = gi;
    int bi = gi;

    if (psc(adims) == 3)
    {
        if (psc(bdims) == 3)
        {
            if (psc(bw) == 1 && psc(bh) == 1)
            {
                // special type 1: b at (0, 0, gz)
                bi = gz * psc(bcstep);
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(bc) == 1)
            {
                // special type 2: b at (gx, gy, 0)
                bi = gy * psc(bw) + gx;
            }

            if (psc(aw) == 1 && psc(ah) == 1)
            {
                // special type 3: a at (0, 0, gz)
                ai = gz * psc(acstep);
            }

            if (psc(bw) == psc(aw) && psc(bh) == psc(ah) && psc(ac) == 1)
            {
                // special type 4: a at (gx, gy, 0)
                ai = gy * psc(aw) + gx;
            }

            if (psc(aw) != 1 && psc(bw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 5: b at (0, gy, gz), bw is 1 so the row stride is 1
                bi = gz * psc(bcstep) + gy;
            }

            if (psc(bw) == psc(aw) && psc(ah) != 1 && psc(bh) == 1 && psc(bc) == psc(ac))
            {
                // special type 6: b at (gx, 0, gz)
                bi = gz * psc(bcstep) + gx;
            }

            if (psc(bw) != 1 && psc(aw) == 1 && psc(bh) == psc(ah) && psc(bc) == psc(ac))
            {
                // special type 7: a at (0, gy, gz), aw is 1 so the row stride is 1
                ai = gz * psc(acstep) + gy;
            }

            if (psc(bw) == psc(aw) && psc(bh) != 1 && psc(ah) == 1 && psc(bc) == psc(ac))
            {
                // special type 8: a at (gx, 0, gz)
                ai = gz * psc(acstep) + gx;
            }
        }

        if (psc(bdims) == 2)
        {
            // type 18: b at (gy, gz)
            bi = gz * psc(bw) + gy;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 16: b is a single element
                bi = 0;
            }
            else
            {
                // type 17: b at (gz)
                bi = gz;
            }
        }
    }
    else if (psc(adims) == 2)
    {
        if (psc(bdims) == 3)
        {
            // type 14: a at (gy, gz)
            ai = gz * psc(aw) + gy;
        }

        if (psc(bdims) == 1)
        {
            if (psc(bw) == 1)
            {
                // type 11: b is a single element
                bi = 0;
            }
            else
            {
                // type 12: b at (gy)
                bi = gy;
            }
        }
    }
    else if (psc(adims) == 1)
    {
        if (psc(aw) == 1)
        {
            // type 2 3 4: a is a single element
            ai = 0;
        }
        else
        {
            if (psc(bdims) == 3)
            {
                // type 9: a at (gz)
                ai = gz;
            }

            if (psc(bdims) == 2)
            {
                // type 8: a at (gy)
                ai = gy;
            }

            if (psc(bdims) == 1)
            {
                if (psc(bw) == 1)
                {
                    // type 6: b is a single element
                    bi = 0;
                }
            }
        }
    }

    afpvec8 v1 = buffer_ld8(a_blob_data, ai);
    afpvec8 v2 = buffer_ld8(b_blob_data, bi);
#endif

    // The operator is applied separately to the two indexable parts ([0], [1])
    // of the 8-wide value.
    // NOTE(review): res is left unwritten when op_type is outside 0..8 - confirm
    // the host never selects such a value.
    afpvec8 res;

    if (op_type == 0)
    {
        res[0] = v1[0] + v2[0];
        res[1] = v1[1] + v2[1];
    }
    if (op_type == 1)
    {
        res[0] = v1[0] - v2[0];
        res[1] = v1[1] - v2[1];
    }
    if (op_type == 2)
    {
        res[0] = v1[0] * v2[0];
        res[1] = v1[1] * v2[1];
    }
    if (op_type == 3)
    {
        res[0] = v1[0] / v2[0];
        res[1] = v1[1] / v2[1];
    }
    if (op_type == 4)
    {
        res[0] = max(v1[0], v2[0]);
        res[1] = max(v1[1], v2[1]);
    }
    if (op_type == 5)
    {
        res[0] = min(v1[0], v2[0]);
        res[1] = min(v1[1], v2[1]);
    }
    if (op_type == 6)
    {
        res[0] = pow(v1[0], v2[0]);
        res[1] = pow(v1[1], v2[1]);
    }
    if (op_type == 7)
    {
        // reversed subtraction
        res[0] = v2[0] - v1[0];
        res[1] = v2[1] - v1[1];
    }
    if (op_type == 8)
    {
        // reversed division
        res[0] = v2[0] / v1[0];
        res[1] = v2[1] / v1[1];
    }

#if NCNN_image_shader
    image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
#else
    buffer_st8(top_blob_data, gi, res);
#endif
}