1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "concat_vulkan.h"
16 
17 #include "layer_shader_type.h"
18 
19 namespace ncnn {
20 
Concat_vulkan()21 Concat_vulkan::Concat_vulkan()
22 {
23     support_vulkan = true;
24     support_image_storage = true;
25 
26     pipeline_concat[0] = 0;
27     pipeline_concat[1] = 0;
28     pipeline_concat_pack4[0] = 0;
29     pipeline_concat_pack4[1] = 0;
30     pipeline_concat_pack4to1[0] = 0;
31     pipeline_concat_pack4to1[1] = 0;
32     pipeline_concat_pack8[0] = 0;
33     pipeline_concat_pack8[1] = 0;
34     pipeline_concat_pack8to4[0] = 0;
35     pipeline_concat_pack8to4[1] = 0;
36     pipeline_concat_pack8to1[0] = 0;
37     pipeline_concat_pack8to1[1] = 0;
38 }
39 
create_pipeline(const Option & _opt)40 int Concat_vulkan::create_pipeline(const Option& _opt)
41 {
42     Option opt = _opt;
43 
44     const Mat& shape = bottom_shapes.empty() ? Mat() : bottom_shapes[0];
45     const Mat& out_shape = top_shapes.empty() ? Mat() : top_shapes[0];
46     int positive_axis = axis < 0 ? shape.dims + axis : axis;
47 
48     int out_elempack = 1;
49     if (out_shape.dims == 1) out_elempack = opt.use_shader_pack8 && out_shape.w % 8 == 0 ? 8 : out_shape.w % 4 == 0 ? 4 : 1;
50     if (out_shape.dims == 2) out_elempack = opt.use_shader_pack8 && out_shape.h % 8 == 0 ? 8 : out_shape.h % 4 == 0 ? 4 : 1;
51     if (out_shape.dims == 3) out_elempack = opt.use_shader_pack8 && out_shape.c % 8 == 0 ? 8 : out_shape.c % 4 == 0 ? 4 : 1;
52 
53     int elempack = 1;
54     if (positive_axis == 0)
55     {
56         if (shape.dims == 1) elempack = opt.use_shader_pack8 && shape.w % 8 == 0 ? 8 : shape.w % 4 == 0 ? 4 : 1;
57         if (shape.dims == 2) elempack = opt.use_shader_pack8 && shape.h % 8 == 0 ? 8 : shape.h % 4 == 0 ? 4 : 1;
58         if (shape.dims == 3) elempack = opt.use_shader_pack8 && shape.c % 8 == 0 ? 8 : shape.c % 4 == 0 ? 4 : 1;
59 
60         for (size_t b = 1; b < bottom_shapes.size(); b++)
61         {
62             const Mat& shape1 = bottom_shapes[b];
63 
64             int elempack1 = 1;
65             if (shape1.dims == 1) elempack1 = opt.use_shader_pack8 && shape1.w % 8 == 0 ? 8 : shape1.w % 4 == 0 ? 4 : 1;
66             if (shape1.dims == 2) elempack1 = opt.use_shader_pack8 && shape1.h % 8 == 0 ? 8 : shape1.h % 4 == 0 ? 4 : 1;
67             if (shape1.dims == 3) elempack1 = opt.use_shader_pack8 && shape1.c % 8 == 0 ? 8 : shape1.c % 4 == 0 ? 4 : 1;
68 
69             elempack = std::min(elempack, elempack1);
70         }
71     }
72     else
73     {
74         elempack = out_elempack;
75     }
76 
77     size_t elemsize;
78     if (opt.use_fp16_storage)
79     {
80         elemsize = elempack * 2u;
81     }
82     else if (opt.use_fp16_packed)
83     {
84         elemsize = elempack == 1 ? 4u : elempack * 2u;
85     }
86     else
87     {
88         elemsize = elempack * 4u;
89     }
90 
91     Mat out_shape_unpacked;
92     if (out_shape.dims == 1) out_shape_unpacked = Mat(out_shape.w / elempack, (void*)0, elemsize, elempack);
93     if (out_shape.dims == 2) out_shape_unpacked = Mat(out_shape.w, out_shape.h / elempack, (void*)0, elemsize, elempack);
94     if (out_shape.dims == 3) out_shape_unpacked = Mat(out_shape.w, out_shape.h, out_shape.c / elempack, (void*)0, elemsize, elempack);
95 
96     if (!vkdev->shape_support_image_storage(out_shape_unpacked))
97     {
98         support_image_storage = false;
99         opt.use_image_storage = false;
100     }
101 
102     std::vector<vk_specialization_type> specializations(1 + 10);
103     specializations[0].i = axis;
104     specializations[1 + 0].i = 0; // TODO handle shape_packed for concat2
105     specializations[1 + 1].i = 0;
106     specializations[1 + 2].i = 0;
107     specializations[1 + 3].i = 0;
108     specializations[1 + 4].i = 0;
109     specializations[1 + 5].i = out_shape_unpacked.dims;
110     specializations[1 + 6].i = out_shape_unpacked.w;
111     specializations[1 + 7].i = out_shape_unpacked.h;
112     specializations[1 + 8].i = out_shape_unpacked.c;
113     specializations[1 + 9].i = out_shape_unpacked.cstep;
114 
115     Mat local_size_xyz; // TODO more precise group size guessed from out_shape_unpacked
116     if (out_shape_unpacked.dims == 1)
117     {
118         local_size_xyz.w = 64;
119         local_size_xyz.h = 1;
120         local_size_xyz.c = 1;
121     }
122     if (out_shape_unpacked.dims == 2)
123     {
124         local_size_xyz.w = 8;
125         local_size_xyz.h = 8;
126         local_size_xyz.c = 1;
127     }
128     if (out_shape_unpacked.dims == 3)
129     {
130         local_size_xyz.w = 4;
131         local_size_xyz.h = 4;
132         local_size_xyz.c = 4;
133     }
134 
135     // pack1
136     if (shape.dims == 0 || elempack == 1)
137     {
138         pipeline_concat[0] = new Pipeline(vkdev);
139         pipeline_concat[0]->set_optimal_local_size_xyz(local_size_xyz);
140         pipeline_concat[0]->create(LayerShaderType::concat, opt, specializations);
141         pipeline_concat[1] = new Pipeline(vkdev);
142         pipeline_concat[1]->set_optimal_local_size_xyz(local_size_xyz);
143         pipeline_concat[1]->create(LayerShaderType::concat, opt, specializations);
144     }
145 
146     // pack4
147     if (shape.dims == 0 || elempack == 4)
148     {
149         pipeline_concat_pack4[0] = new Pipeline(vkdev);
150         pipeline_concat_pack4[0]->set_optimal_local_size_xyz(local_size_xyz);
151         pipeline_concat_pack4[0]->create(LayerShaderType::concat_pack4, opt, specializations);
152         pipeline_concat_pack4[1] = new Pipeline(vkdev);
153         pipeline_concat_pack4[1]->set_optimal_local_size_xyz(local_size_xyz);
154         pipeline_concat_pack4[1]->create(LayerShaderType::concat_pack4, opt, specializations);
155     }
156 
157     // pack4to1
158     if ((positive_axis <= 0 && shape.dims == 0) || elempack == 1)
159     {
160         pipeline_concat_pack4to1[0] = new Pipeline(vkdev);
161         pipeline_concat_pack4to1[0]->set_optimal_local_size_xyz(local_size_xyz);
162         pipeline_concat_pack4to1[0]->create(LayerShaderType::concat_pack4to1, opt, specializations);
163         pipeline_concat_pack4to1[1] = new Pipeline(vkdev);
164         pipeline_concat_pack4to1[1]->set_optimal_local_size_xyz(local_size_xyz);
165         pipeline_concat_pack4to1[1]->create(LayerShaderType::concat_pack4to1, opt, specializations);
166     }
167 
168     // pack8
169     if (opt.use_shader_pack8 && (shape.dims == 0 || elempack == 8))
170     {
171         pipeline_concat_pack8[0] = new Pipeline(vkdev);
172         pipeline_concat_pack8[0]->set_optimal_local_size_xyz(local_size_xyz);
173         pipeline_concat_pack8[0]->create(LayerShaderType::concat_pack8, opt, specializations);
174         pipeline_concat_pack8[1] = new Pipeline(vkdev);
175         pipeline_concat_pack8[1]->set_optimal_local_size_xyz(local_size_xyz);
176         pipeline_concat_pack8[1]->create(LayerShaderType::concat_pack8, opt, specializations);
177     }
178 
179     // pack8to4
180     if (opt.use_shader_pack8 && ((positive_axis <= 0 && shape.dims == 0) || elempack == 4))
181     {
182         pipeline_concat_pack8to4[0] = new Pipeline(vkdev);
183         pipeline_concat_pack8to4[0]->set_optimal_local_size_xyz(local_size_xyz);
184         pipeline_concat_pack8to4[0]->create(LayerShaderType::concat_pack8to4, opt, specializations);
185         pipeline_concat_pack8to4[1] = new Pipeline(vkdev);
186         pipeline_concat_pack8to4[1]->set_optimal_local_size_xyz(local_size_xyz);
187         pipeline_concat_pack8to4[1]->create(LayerShaderType::concat_pack8to4, opt, specializations);
188     }
189 
190     // pack8to1
191     if (opt.use_shader_pack8 && ((positive_axis <= 0 && shape.dims == 0) || elempack == 1))
192     {
193         pipeline_concat_pack8to1[0] = new Pipeline(vkdev);
194         pipeline_concat_pack8to1[0]->set_optimal_local_size_xyz(local_size_xyz);
195         pipeline_concat_pack8to1[0]->create(LayerShaderType::concat_pack8to1, opt, specializations);
196         pipeline_concat_pack8to1[1] = new Pipeline(vkdev);
197         pipeline_concat_pack8to1[1]->set_optimal_local_size_xyz(local_size_xyz);
198         pipeline_concat_pack8to1[1]->create(LayerShaderType::concat_pack8to1, opt, specializations);
199     }
200 
201     return 0;
202 }
203 
destroy_pipeline(const Option &)204 int Concat_vulkan::destroy_pipeline(const Option& /*opt*/)
205 {
206     delete pipeline_concat[0];
207     delete pipeline_concat[1];
208     pipeline_concat[0] = 0;
209     pipeline_concat[1] = 0;
210 
211     delete pipeline_concat_pack4[0];
212     delete pipeline_concat_pack4[1];
213     pipeline_concat_pack4[0] = 0;
214     pipeline_concat_pack4[1] = 0;
215 
216     delete pipeline_concat_pack4to1[0];
217     delete pipeline_concat_pack4to1[1];
218     pipeline_concat_pack4to1[0] = 0;
219     pipeline_concat_pack4to1[1] = 0;
220 
221     delete pipeline_concat_pack8[0];
222     delete pipeline_concat_pack8[1];
223     pipeline_concat_pack8[0] = 0;
224     pipeline_concat_pack8[1] = 0;
225 
226     delete pipeline_concat_pack8to4[0];
227     delete pipeline_concat_pack8to4[1];
228     pipeline_concat_pack8to4[0] = 0;
229     pipeline_concat_pack8to4[1] = 0;
230 
231     delete pipeline_concat_pack8to1[0];
232     delete pipeline_concat_pack8to1[1];
233     pipeline_concat_pack8to1[0] = 0;
234     pipeline_concat_pack8to1[1] = 0;
235 
236     return 0;
237 }
238 
forward(const std::vector<VkMat> & bottom_blobs,std::vector<VkMat> & top_blobs,VkCompute & cmd,const Option & opt) const239 int Concat_vulkan::forward(const std::vector<VkMat>& bottom_blobs, std::vector<VkMat>& top_blobs, VkCompute& cmd, const Option& opt) const
240 {
241     int dims = bottom_blobs[0].dims;
242     int positive_axis = axis < 0 ? dims + axis : axis;
243 
244     if (dims == 1) // positive_axis == 0
245     {
246         // concat vector
247         // total length
248         size_t elemsize = bottom_blobs[0].elemsize;
249         int elempack = bottom_blobs[0].elempack;
250         int top_w = 0;
251         for (size_t b = 0; b < bottom_blobs.size(); b++)
252         {
253             const VkMat& bottom_blob = bottom_blobs[b];
254             elemsize = std::min(elemsize, bottom_blob.elemsize);
255             elempack = std::min(elempack, bottom_blob.elempack);
256             top_w += bottom_blob.w * bottom_blob.elempack;
257         }
258 
259         int out_elempack = opt.use_shader_pack8 && top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1;
260         size_t out_elemsize = elemsize / elempack * out_elempack;
261 
262         if (opt.use_fp16_packed && !opt.use_fp16_storage)
263         {
264             if (out_elempack == 8) out_elemsize = 8 * 2u;
265             if (out_elempack == 4) out_elemsize = 4 * 2u;
266             if (out_elempack == 1) out_elemsize = 4u;
267         }
268 
269         VkMat& top_blob = top_blobs[0];
270         top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
271         if (top_blob.empty())
272             return -100;
273 
274         VkMat top_blob_unpacked = top_blob;
275         if (elempack < out_elempack)
276         {
277             top_blob_unpacked.create(top_w / elempack, elemsize, elempack, opt.workspace_vkallocator);
278             if (top_blob_unpacked.empty())
279                 return -100;
280         }
281 
282         int woffset = 0;
283         for (size_t b = 0; b < bottom_blobs.size(); b++)
284         {
285             const VkMat& bottom_blob = bottom_blobs[b];
286 
287             std::vector<VkMat> bindings(2);
288             bindings[0] = bottom_blob;
289             bindings[1] = top_blob_unpacked;
290 
291             std::vector<vk_constant_type> constants(11);
292             constants[0].i = bottom_blob.dims;
293             constants[1].i = bottom_blob.w;
294             constants[2].i = bottom_blob.h;
295             constants[3].i = bottom_blob.c;
296             constants[4].i = bottom_blob.cstep;
297             constants[5].i = top_blob_unpacked.dims;
298             constants[6].i = top_blob_unpacked.w;
299             constants[7].i = top_blob_unpacked.h;
300             constants[8].i = top_blob_unpacked.c;
301             constants[9].i = top_blob_unpacked.cstep;
302             constants[10].i = woffset;
303 
304             const Pipeline* pipeline = 0;
305             if (bottom_blob.elempack == 1 && elempack == 1)
306             {
307                 pipeline = pipeline_concat[b % 2];
308             }
309             else if (bottom_blob.elempack == 4 && elempack == 4)
310             {
311                 pipeline = pipeline_concat_pack4[b % 2];
312             }
313             else if (bottom_blob.elempack == 4 && elempack == 1)
314             {
315                 pipeline = pipeline_concat_pack4to1[b % 2];
316             }
317             else if (bottom_blob.elempack == 8 && elempack == 8)
318             {
319                 pipeline = pipeline_concat_pack8[b % 2];
320             }
321             else if (bottom_blob.elempack == 8 && elempack == 4)
322             {
323                 pipeline = pipeline_concat_pack8to4[b % 2];
324             }
325             else if (bottom_blob.elempack == 8 && elempack == 1)
326             {
327                 pipeline = pipeline_concat_pack8to1[b % 2];
328             }
329 
330             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
331 
332             woffset += bottom_blob.w * bottom_blob.elempack / elempack;
333         }
334 
335         // packing
336         if (elempack < out_elempack)
337         {
338             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
339         }
340 
341         return 0;
342     }
343 
344     if (dims == 2 && positive_axis == 0)
345     {
346         // concat image
347         int w = bottom_blobs[0].w;
348 
349         // total height
350         size_t elemsize = bottom_blobs[0].elemsize;
351         int elempack = bottom_blobs[0].elempack;
352         int top_h = 0;
353         for (size_t b = 0; b < bottom_blobs.size(); b++)
354         {
355             const VkMat& bottom_blob = bottom_blobs[b];
356             elemsize = std::min(elemsize, bottom_blob.elemsize);
357             elempack = std::min(elempack, bottom_blob.elempack);
358             top_h += bottom_blob.h * bottom_blob.elempack;
359         }
360 
361         int out_elempack = opt.use_shader_pack8 && top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1;
362         size_t out_elemsize = elemsize / elempack * out_elempack;
363 
364         if (opt.use_fp16_packed && !opt.use_fp16_storage)
365         {
366             if (out_elempack == 8) out_elemsize = 8 * 2u;
367             if (out_elempack == 4) out_elemsize = 4 * 2u;
368             if (out_elempack == 1) out_elemsize = 4u;
369         }
370 
371         VkMat& top_blob = top_blobs[0];
372         top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
373         if (top_blob.empty())
374             return -100;
375 
376         VkMat top_blob_unpacked = top_blob;
377         if (elempack < out_elempack)
378         {
379             top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_vkallocator);
380             if (top_blob_unpacked.empty())
381                 return -100;
382         }
383 
384         int hoffset = 0;
385         for (size_t b = 0; b < bottom_blobs.size(); b++)
386         {
387             const VkMat& bottom_blob = bottom_blobs[b];
388 
389             std::vector<VkMat> bindings(2);
390             bindings[0] = bottom_blob;
391             bindings[1] = top_blob_unpacked;
392 
393             std::vector<vk_constant_type> constants(11);
394             constants[0].i = bottom_blob.dims;
395             constants[1].i = bottom_blob.w;
396             constants[2].i = bottom_blob.h;
397             constants[3].i = bottom_blob.c;
398             constants[4].i = bottom_blob.cstep;
399             constants[5].i = top_blob_unpacked.dims;
400             constants[6].i = top_blob_unpacked.w;
401             constants[7].i = top_blob_unpacked.h;
402             constants[8].i = top_blob_unpacked.c;
403             constants[9].i = top_blob_unpacked.cstep;
404             constants[10].i = hoffset;
405 
406             const Pipeline* pipeline = 0;
407             if (bottom_blob.elempack == 1 && elempack == 1)
408             {
409                 pipeline = pipeline_concat[b % 2];
410             }
411             else if (bottom_blob.elempack == 4 && elempack == 4)
412             {
413                 pipeline = pipeline_concat_pack4[b % 2];
414             }
415             else if (bottom_blob.elempack == 4 && elempack == 1)
416             {
417                 pipeline = pipeline_concat_pack4to1[b % 2];
418             }
419             else if (bottom_blob.elempack == 8 && elempack == 8)
420             {
421                 pipeline = pipeline_concat_pack8[b % 2];
422             }
423             else if (bottom_blob.elempack == 8 && elempack == 4)
424             {
425                 pipeline = pipeline_concat_pack8to4[b % 2];
426             }
427             else if (bottom_blob.elempack == 8 && elempack == 1)
428             {
429                 pipeline = pipeline_concat_pack8to1[b % 2];
430             }
431 
432             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
433 
434             hoffset += bottom_blob.h * bottom_blob.elempack / elempack;
435         }
436 
437         // packing
438         if (elempack < out_elempack)
439         {
440             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
441         }
442 
443         return 0;
444     }
445 
446     if (dims == 2 && positive_axis == 1)
447     {
448         // interleave image row
449         int h = bottom_blobs[0].h;
450         size_t elemsize = bottom_blobs[0].elemsize;
451         int elempack = bottom_blobs[0].elempack;
452 
453         // total width
454         int top_w = 0;
455         for (size_t b = 0; b < bottom_blobs.size(); b++)
456         {
457             const VkMat& bottom_blob = bottom_blobs[b];
458             top_w += bottom_blob.w;
459         }
460 
461         VkMat& top_blob = top_blobs[0];
462         top_blob.create(top_w, h, elemsize, elempack, opt.blob_vkallocator);
463         if (top_blob.empty())
464             return -100;
465 
466         int woffset = 0;
467         for (size_t b = 0; b < bottom_blobs.size(); b++)
468         {
469             const VkMat& bottom_blob = bottom_blobs[b];
470 
471             std::vector<VkMat> bindings(2);
472             bindings[0] = bottom_blob;
473             bindings[1] = top_blob;
474 
475             std::vector<vk_constant_type> constants(11);
476             constants[0].i = bottom_blob.dims;
477             constants[1].i = bottom_blob.w;
478             constants[2].i = bottom_blob.h;
479             constants[3].i = bottom_blob.c;
480             constants[4].i = bottom_blob.cstep;
481             constants[5].i = top_blob.dims;
482             constants[6].i = top_blob.w;
483             constants[7].i = top_blob.h;
484             constants[8].i = top_blob.c;
485             constants[9].i = top_blob.cstep;
486             constants[10].i = woffset;
487 
488             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
489                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
490                                        : pipeline_concat[b % 2];
491 
492             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
493 
494             woffset += bottom_blob.w;
495         }
496 
497         return 0;
498     }
499 
500     if (dims == 3 && positive_axis == 0)
501     {
502         // concat dim
503         int w = bottom_blobs[0].w;
504         int h = bottom_blobs[0].h;
505 
506         // total channels
507         size_t elemsize = bottom_blobs[0].elemsize;
508         int elempack = bottom_blobs[0].elempack;
509         int top_channels = 0;
510         for (size_t b = 0; b < bottom_blobs.size(); b++)
511         {
512             const VkMat& bottom_blob = bottom_blobs[b];
513             elemsize = std::min(elemsize, bottom_blob.elemsize);
514             elempack = std::min(elempack, bottom_blob.elempack);
515             top_channels += bottom_blob.c * bottom_blob.elempack;
516         }
517 
518         int out_elempack = opt.use_shader_pack8 && top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1;
519         size_t out_elemsize = elemsize / elempack * out_elempack;
520 
521         if (opt.use_fp16_packed && !opt.use_fp16_storage)
522         {
523             if (out_elempack == 8) out_elemsize = 8 * 2u;
524             if (out_elempack == 4) out_elemsize = 4 * 2u;
525             if (out_elempack == 1) out_elemsize = 4u;
526         }
527 
528         VkMat& top_blob = top_blobs[0];
529         top_blob.create(w, h, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
530         if (top_blob.empty())
531             return -100;
532 
533         VkMat top_blob_unpacked = top_blob;
534         if (elempack < out_elempack)
535         {
536             top_blob_unpacked.create(w, h, top_channels / elempack, elemsize, elempack, opt.workspace_vkallocator);
537             if (top_blob_unpacked.empty())
538                 return -100;
539         }
540 
541         int coffset = 0;
542         for (size_t b = 0; b < bottom_blobs.size(); b++)
543         {
544             const VkMat& bottom_blob = bottom_blobs[b];
545 
546             std::vector<VkMat> bindings(2);
547             bindings[0] = bottom_blob;
548             bindings[1] = top_blob_unpacked;
549 
550             std::vector<vk_constant_type> constants(11);
551             constants[0].i = bottom_blob.dims;
552             constants[1].i = bottom_blob.w;
553             constants[2].i = bottom_blob.h;
554             constants[3].i = bottom_blob.c;
555             constants[4].i = bottom_blob.cstep;
556             constants[5].i = top_blob_unpacked.dims;
557             constants[6].i = top_blob_unpacked.w;
558             constants[7].i = top_blob_unpacked.h;
559             constants[8].i = top_blob_unpacked.c;
560             constants[9].i = top_blob_unpacked.cstep;
561             constants[10].i = coffset;
562 
563             const Pipeline* pipeline = 0;
564             if (bottom_blob.elempack == 1 && elempack == 1)
565             {
566                 pipeline = pipeline_concat[b % 2];
567             }
568             else if (bottom_blob.elempack == 4 && elempack == 4)
569             {
570                 pipeline = pipeline_concat_pack4[b % 2];
571             }
572             else if (bottom_blob.elempack == 4 && elempack == 1)
573             {
574                 pipeline = pipeline_concat_pack4to1[b % 2];
575             }
576             else if (bottom_blob.elempack == 8 && elempack == 8)
577             {
578                 pipeline = pipeline_concat_pack8[b % 2];
579             }
580             else if (bottom_blob.elempack == 8 && elempack == 4)
581             {
582                 pipeline = pipeline_concat_pack8to4[b % 2];
583             }
584             else if (bottom_blob.elempack == 8 && elempack == 1)
585             {
586                 pipeline = pipeline_concat_pack8to1[b % 2];
587             }
588 
589             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
590 
591             coffset += bottom_blob.c * bottom_blob.elempack / elempack;
592         }
593 
594         // packing
595         if (elempack < out_elempack)
596         {
597             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
598         }
599 
600         return 0;
601     }
602 
603     if (dims == 3 && positive_axis == 1)
604     {
605         // interleave dim height
606         int w = bottom_blobs[0].w;
607         int channels = bottom_blobs[0].c;
608         size_t elemsize = bottom_blobs[0].elemsize;
609         int elempack = bottom_blobs[0].elempack;
610 
611         // total height
612         int top_h = 0;
613         for (size_t b = 0; b < bottom_blobs.size(); b++)
614         {
615             const VkMat& bottom_blob = bottom_blobs[b];
616             top_h += bottom_blob.h;
617         }
618 
619         VkMat& top_blob = top_blobs[0];
620         top_blob.create(w, top_h, channels, elemsize, elempack, opt.blob_vkallocator);
621         if (top_blob.empty())
622             return -100;
623 
624         int hoffset = 0;
625         for (size_t b = 0; b < bottom_blobs.size(); b++)
626         {
627             const VkMat& bottom_blob = bottom_blobs[b];
628 
629             std::vector<VkMat> bindings(2);
630             bindings[0] = bottom_blob;
631             bindings[1] = top_blob;
632 
633             std::vector<vk_constant_type> constants(11);
634             constants[0].i = bottom_blob.dims;
635             constants[1].i = bottom_blob.w;
636             constants[2].i = bottom_blob.h;
637             constants[3].i = bottom_blob.c;
638             constants[4].i = bottom_blob.cstep;
639             constants[5].i = top_blob.dims;
640             constants[6].i = top_blob.w;
641             constants[7].i = top_blob.h;
642             constants[8].i = top_blob.c;
643             constants[9].i = top_blob.cstep;
644             constants[10].i = hoffset;
645 
646             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
647                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
648                                        : pipeline_concat[b % 2];
649 
650             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
651 
652             hoffset += bottom_blob.h;
653         }
654 
655         return 0;
656     }
657 
658     if (dims == 3 && positive_axis == 2)
659     {
660         // interleave dim width
661         int h = bottom_blobs[0].h;
662         int channels = bottom_blobs[0].c;
663         size_t elemsize = bottom_blobs[0].elemsize;
664         int elempack = bottom_blobs[0].elempack;
665 
666         // total height
667         int top_w = 0;
668         for (size_t b = 0; b < bottom_blobs.size(); b++)
669         {
670             const VkMat& bottom_blob = bottom_blobs[b];
671             top_w += bottom_blob.w;
672         }
673 
674         VkMat& top_blob = top_blobs[0];
675         top_blob.create(top_w, h, channels, elemsize, elempack, opt.blob_vkallocator);
676         if (top_blob.empty())
677             return -100;
678 
679         int woffset = 0;
680         for (size_t b = 0; b < bottom_blobs.size(); b++)
681         {
682             const VkMat& bottom_blob = bottom_blobs[b];
683 
684             std::vector<VkMat> bindings(2);
685             bindings[0] = bottom_blob;
686             bindings[1] = top_blob;
687 
688             std::vector<vk_constant_type> constants(11);
689             constants[0].i = bottom_blob.dims;
690             constants[1].i = bottom_blob.w;
691             constants[2].i = bottom_blob.h;
692             constants[3].i = bottom_blob.c;
693             constants[4].i = bottom_blob.cstep;
694             constants[5].i = top_blob.dims;
695             constants[6].i = top_blob.w;
696             constants[7].i = top_blob.h;
697             constants[8].i = top_blob.c;
698             constants[9].i = top_blob.cstep;
699             constants[10].i = woffset;
700 
701             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
702                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
703                                        : pipeline_concat[b % 2];
704 
705             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
706 
707             woffset += bottom_blob.w;
708         }
709 
710         return 0;
711     }
712 
713     return 0;
714 }
715 
forward(const std::vector<VkImageMat> & bottom_blobs,std::vector<VkImageMat> & top_blobs,VkCompute & cmd,const Option & opt) const716 int Concat_vulkan::forward(const std::vector<VkImageMat>& bottom_blobs, std::vector<VkImageMat>& top_blobs, VkCompute& cmd, const Option& opt) const
717 {
718     int dims = bottom_blobs[0].dims;
719     int positive_axis = axis < 0 ? dims + axis : axis;
720 
721     if (dims == 1) // positive_axis == 0
722     {
723         // concat vector
724         // total length
725         size_t elemsize = bottom_blobs[0].elemsize;
726         int elempack = bottom_blobs[0].elempack;
727         int top_w = 0;
728         for (size_t b = 0; b < bottom_blobs.size(); b++)
729         {
730             const VkImageMat& bottom_blob = bottom_blobs[b];
731             elemsize = std::min(elemsize, bottom_blob.elemsize);
732             elempack = std::min(elempack, bottom_blob.elempack);
733             top_w += bottom_blob.w * bottom_blob.elempack;
734         }
735 
736         int out_elempack = opt.use_shader_pack8 && top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1;
737         size_t out_elemsize = elemsize / elempack * out_elempack;
738 
739         if (opt.use_fp16_packed && !opt.use_fp16_storage)
740         {
741             if (out_elempack == 8) out_elemsize = 8 * 2u;
742             if (out_elempack == 4) out_elemsize = 4 * 2u;
743             if (out_elempack == 1) out_elemsize = 4u;
744         }
745 
746         VkImageMat& top_blob = top_blobs[0];
747         top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
748         if (top_blob.empty())
749             return -100;
750 
751         VkImageMat top_blob_unpacked = top_blob;
752         if (elempack < out_elempack)
753         {
754             top_blob_unpacked.create(top_w / elempack, elemsize, elempack, opt.workspace_vkallocator);
755             if (top_blob_unpacked.empty())
756                 return -100;
757         }
758 
759         int woffset = 0;
760         for (size_t b = 0; b < bottom_blobs.size(); b++)
761         {
762             const VkImageMat& bottom_blob = bottom_blobs[b];
763 
764             std::vector<VkImageMat> bindings(2);
765             bindings[0] = bottom_blob;
766             bindings[1] = top_blob_unpacked;
767 
768             std::vector<vk_constant_type> constants(11);
769             constants[0].i = bottom_blob.dims;
770             constants[1].i = bottom_blob.w;
771             constants[2].i = bottom_blob.h;
772             constants[3].i = bottom_blob.c;
773             constants[4].i = 0; //bottom_blob.cstep;
774             constants[5].i = top_blob_unpacked.dims;
775             constants[6].i = top_blob_unpacked.w;
776             constants[7].i = top_blob_unpacked.h;
777             constants[8].i = top_blob_unpacked.c;
778             constants[9].i = 0; //top_blob_unpacked.cstep;
779             constants[10].i = woffset;
780 
781             const Pipeline* pipeline = 0;
782             if (bottom_blob.elempack == 1 && elempack == 1)
783             {
784                 pipeline = pipeline_concat[b % 2];
785             }
786             else if (bottom_blob.elempack == 4 && elempack == 4)
787             {
788                 pipeline = pipeline_concat_pack4[b % 2];
789             }
790             else if (bottom_blob.elempack == 4 && elempack == 1)
791             {
792                 pipeline = pipeline_concat_pack4to1[b % 2];
793             }
794             else if (bottom_blob.elempack == 8 && elempack == 8)
795             {
796                 pipeline = pipeline_concat_pack8[b % 2];
797             }
798             else if (bottom_blob.elempack == 8 && elempack == 4)
799             {
800                 pipeline = pipeline_concat_pack8to4[b % 2];
801             }
802             else if (bottom_blob.elempack == 8 && elempack == 1)
803             {
804                 pipeline = pipeline_concat_pack8to1[b % 2];
805             }
806 
807             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
808 
809             woffset += bottom_blob.w * bottom_blob.elempack / elempack;
810         }
811 
812         // packing
813         if (elempack < out_elempack)
814         {
815             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
816         }
817 
818         return 0;
819     }
820 
821     if (dims == 2 && positive_axis == 0)
822     {
823         // concat image
824         int w = bottom_blobs[0].w;
825 
826         // total height
827         size_t elemsize = bottom_blobs[0].elemsize;
828         int elempack = bottom_blobs[0].elempack;
829         int top_h = 0;
830         for (size_t b = 0; b < bottom_blobs.size(); b++)
831         {
832             const VkImageMat& bottom_blob = bottom_blobs[b];
833             elemsize = std::min(elemsize, bottom_blob.elemsize);
834             elempack = std::min(elempack, bottom_blob.elempack);
835             top_h += bottom_blob.h * bottom_blob.elempack;
836         }
837 
838         int out_elempack = opt.use_shader_pack8 && top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1;
839         size_t out_elemsize = elemsize / elempack * out_elempack;
840 
841         if (opt.use_fp16_packed && !opt.use_fp16_storage)
842         {
843             if (out_elempack == 8) out_elemsize = 8 * 2u;
844             if (out_elempack == 4) out_elemsize = 4 * 2u;
845             if (out_elempack == 1) out_elemsize = 4u;
846         }
847 
848         VkImageMat& top_blob = top_blobs[0];
849         top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
850         if (top_blob.empty())
851             return -100;
852 
853         VkImageMat top_blob_unpacked = top_blob;
854         if (elempack < out_elempack)
855         {
856             top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_vkallocator);
857             if (top_blob_unpacked.empty())
858                 return -100;
859         }
860 
861         int hoffset = 0;
862         for (size_t b = 0; b < bottom_blobs.size(); b++)
863         {
864             const VkImageMat& bottom_blob = bottom_blobs[b];
865 
866             std::vector<VkImageMat> bindings(2);
867             bindings[0] = bottom_blob;
868             bindings[1] = top_blob_unpacked;
869 
870             std::vector<vk_constant_type> constants(11);
871             constants[0].i = bottom_blob.dims;
872             constants[1].i = bottom_blob.w;
873             constants[2].i = bottom_blob.h;
874             constants[3].i = bottom_blob.c;
875             constants[4].i = 0; //bottom_blob.cstep;
876             constants[5].i = top_blob_unpacked.dims;
877             constants[6].i = top_blob_unpacked.w;
878             constants[7].i = top_blob_unpacked.h;
879             constants[8].i = top_blob_unpacked.c;
880             constants[9].i = 0; //top_blob_unpacked.cstep;
881             constants[10].i = hoffset;
882 
883             const Pipeline* pipeline = 0;
884             if (bottom_blob.elempack == 1 && elempack == 1)
885             {
886                 pipeline = pipeline_concat[b % 2];
887             }
888             else if (bottom_blob.elempack == 4 && elempack == 4)
889             {
890                 pipeline = pipeline_concat_pack4[b % 2];
891             }
892             else if (bottom_blob.elempack == 4 && elempack == 1)
893             {
894                 pipeline = pipeline_concat_pack4to1[b % 2];
895             }
896             else if (bottom_blob.elempack == 8 && elempack == 8)
897             {
898                 pipeline = pipeline_concat_pack8[b % 2];
899             }
900             else if (bottom_blob.elempack == 8 && elempack == 4)
901             {
902                 pipeline = pipeline_concat_pack8to4[b % 2];
903             }
904             else if (bottom_blob.elempack == 8 && elempack == 1)
905             {
906                 pipeline = pipeline_concat_pack8to1[b % 2];
907             }
908 
909             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
910 
911             hoffset += bottom_blob.h * bottom_blob.elempack / elempack;
912         }
913 
914         // packing
915         if (elempack < out_elempack)
916         {
917             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
918         }
919 
920         return 0;
921     }
922 
923     if (dims == 2 && positive_axis == 1)
924     {
925         // interleave image row
926         int h = bottom_blobs[0].h;
927         size_t elemsize = bottom_blobs[0].elemsize;
928         int elempack = bottom_blobs[0].elempack;
929 
930         // total width
931         int top_w = 0;
932         for (size_t b = 0; b < bottom_blobs.size(); b++)
933         {
934             const VkImageMat& bottom_blob = bottom_blobs[b];
935             top_w += bottom_blob.w;
936         }
937 
938         VkImageMat& top_blob = top_blobs[0];
939         top_blob.create(top_w, h, elemsize, elempack, opt.blob_vkallocator);
940         if (top_blob.empty())
941             return -100;
942 
943         int woffset = 0;
944         for (size_t b = 0; b < bottom_blobs.size(); b++)
945         {
946             const VkImageMat& bottom_blob = bottom_blobs[b];
947 
948             std::vector<VkImageMat> bindings(2);
949             bindings[0] = bottom_blob;
950             bindings[1] = top_blob;
951 
952             std::vector<vk_constant_type> constants(11);
953             constants[0].i = bottom_blob.dims;
954             constants[1].i = bottom_blob.w;
955             constants[2].i = bottom_blob.h;
956             constants[3].i = bottom_blob.c;
957             constants[4].i = 0; //bottom_blob.cstep;
958             constants[5].i = top_blob.dims;
959             constants[6].i = top_blob.w;
960             constants[7].i = top_blob.h;
961             constants[8].i = top_blob.c;
962             constants[9].i = 0; //top_blob.cstep;
963             constants[10].i = woffset;
964 
965             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
966                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
967                                        : pipeline_concat[b % 2];
968 
969             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
970 
971             woffset += bottom_blob.w;
972         }
973 
974         return 0;
975     }
976 
977     if (dims == 3 && positive_axis == 0)
978     {
979         // concat dim
980         int w = bottom_blobs[0].w;
981         int h = bottom_blobs[0].h;
982 
983         // total channels
984         size_t elemsize = bottom_blobs[0].elemsize;
985         int elempack = bottom_blobs[0].elempack;
986         int top_channels = 0;
987         for (size_t b = 0; b < bottom_blobs.size(); b++)
988         {
989             const VkImageMat& bottom_blob = bottom_blobs[b];
990             elemsize = std::min(elemsize, bottom_blob.elemsize);
991             elempack = std::min(elempack, bottom_blob.elempack);
992             top_channels += bottom_blob.c * bottom_blob.elempack;
993         }
994 
995         int out_elempack = opt.use_shader_pack8 && top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1;
996         size_t out_elemsize = elemsize / elempack * out_elempack;
997 
998         if (opt.use_fp16_packed && !opt.use_fp16_storage)
999         {
1000             if (out_elempack == 8) out_elemsize = 8 * 2u;
1001             if (out_elempack == 4) out_elemsize = 4 * 2u;
1002             if (out_elempack == 1) out_elemsize = 4u;
1003         }
1004 
1005         VkImageMat& top_blob = top_blobs[0];
1006         top_blob.create(w, h, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_vkallocator);
1007         if (top_blob.empty())
1008             return -100;
1009 
1010         VkImageMat top_blob_unpacked = top_blob;
1011         if (elempack < out_elempack)
1012         {
1013             top_blob_unpacked.create(w, h, top_channels / elempack, elemsize, elempack, opt.workspace_vkallocator);
1014             if (top_blob_unpacked.empty())
1015                 return -100;
1016         }
1017 
1018         int coffset = 0;
1019         for (size_t b = 0; b < bottom_blobs.size(); b++)
1020         {
1021             const VkImageMat& bottom_blob = bottom_blobs[b];
1022 
1023             std::vector<VkImageMat> bindings(2);
1024             bindings[0] = bottom_blob;
1025             bindings[1] = top_blob_unpacked;
1026 
1027             std::vector<vk_constant_type> constants(11);
1028             constants[0].i = bottom_blob.dims;
1029             constants[1].i = bottom_blob.w;
1030             constants[2].i = bottom_blob.h;
1031             constants[3].i = bottom_blob.c;
1032             constants[4].i = 0; //bottom_blob.cstep;
1033             constants[5].i = top_blob_unpacked.dims;
1034             constants[6].i = top_blob_unpacked.w;
1035             constants[7].i = top_blob_unpacked.h;
1036             constants[8].i = top_blob_unpacked.c;
1037             constants[9].i = 0; //top_blob_unpacked.cstep;
1038             constants[10].i = coffset;
1039 
1040             const Pipeline* pipeline = 0;
1041             if (bottom_blob.elempack == 1 && elempack == 1)
1042             {
1043                 pipeline = pipeline_concat[b % 2];
1044             }
1045             else if (bottom_blob.elempack == 4 && elempack == 4)
1046             {
1047                 pipeline = pipeline_concat_pack4[b % 2];
1048             }
1049             else if (bottom_blob.elempack == 4 && elempack == 1)
1050             {
1051                 pipeline = pipeline_concat_pack4to1[b % 2];
1052             }
1053             else if (bottom_blob.elempack == 8 && elempack == 8)
1054             {
1055                 pipeline = pipeline_concat_pack8[b % 2];
1056             }
1057             else if (bottom_blob.elempack == 8 && elempack == 4)
1058             {
1059                 pipeline = pipeline_concat_pack8to4[b % 2];
1060             }
1061             else if (bottom_blob.elempack == 8 && elempack == 1)
1062             {
1063                 pipeline = pipeline_concat_pack8to1[b % 2];
1064             }
1065 
1066             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
1067 
1068             coffset += bottom_blob.c * bottom_blob.elempack / elempack;
1069         }
1070 
1071         // packing
1072         if (elempack < out_elempack)
1073         {
1074             vkdev->convert_packing(top_blob_unpacked, top_blob, out_elempack, cmd, opt);
1075         }
1076 
1077         return 0;
1078     }
1079 
1080     if (dims == 3 && positive_axis == 1)
1081     {
1082         // interleave dim height
1083         int w = bottom_blobs[0].w;
1084         int channels = bottom_blobs[0].c;
1085         size_t elemsize = bottom_blobs[0].elemsize;
1086         int elempack = bottom_blobs[0].elempack;
1087 
1088         // total height
1089         int top_h = 0;
1090         for (size_t b = 0; b < bottom_blobs.size(); b++)
1091         {
1092             const VkImageMat& bottom_blob = bottom_blobs[b];
1093             top_h += bottom_blob.h;
1094         }
1095 
1096         VkImageMat& top_blob = top_blobs[0];
1097         top_blob.create(w, top_h, channels, elemsize, elempack, opt.blob_vkallocator);
1098         if (top_blob.empty())
1099             return -100;
1100 
1101         int hoffset = 0;
1102         for (size_t b = 0; b < bottom_blobs.size(); b++)
1103         {
1104             const VkImageMat& bottom_blob = bottom_blobs[b];
1105 
1106             std::vector<VkImageMat> bindings(2);
1107             bindings[0] = bottom_blob;
1108             bindings[1] = top_blob;
1109 
1110             std::vector<vk_constant_type> constants(11);
1111             constants[0].i = bottom_blob.dims;
1112             constants[1].i = bottom_blob.w;
1113             constants[2].i = bottom_blob.h;
1114             constants[3].i = bottom_blob.c;
1115             constants[4].i = 0; //bottom_blob.cstep;
1116             constants[5].i = top_blob.dims;
1117             constants[6].i = top_blob.w;
1118             constants[7].i = top_blob.h;
1119             constants[8].i = top_blob.c;
1120             constants[9].i = 0; //top_blob.cstep;
1121             constants[10].i = hoffset;
1122 
1123             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
1124                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
1125                                        : pipeline_concat[b % 2];
1126 
1127             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
1128 
1129             hoffset += bottom_blob.h;
1130         }
1131 
1132         return 0;
1133     }
1134 
1135     if (dims == 3 && positive_axis == 2)
1136     {
1137         // interleave dim width
1138         int h = bottom_blobs[0].h;
1139         int channels = bottom_blobs[0].c;
1140         size_t elemsize = bottom_blobs[0].elemsize;
1141         int elempack = bottom_blobs[0].elempack;
1142 
        // total width
1144         int top_w = 0;
1145         for (size_t b = 0; b < bottom_blobs.size(); b++)
1146         {
1147             const VkImageMat& bottom_blob = bottom_blobs[b];
1148             top_w += bottom_blob.w;
1149         }
1150 
1151         VkImageMat& top_blob = top_blobs[0];
1152         top_blob.create(top_w, h, channels, elemsize, elempack, opt.blob_vkallocator);
1153         if (top_blob.empty())
1154             return -100;
1155 
1156         int woffset = 0;
1157         for (size_t b = 0; b < bottom_blobs.size(); b++)
1158         {
1159             const VkImageMat& bottom_blob = bottom_blobs[b];
1160 
1161             std::vector<VkImageMat> bindings(2);
1162             bindings[0] = bottom_blob;
1163             bindings[1] = top_blob;
1164 
1165             std::vector<vk_constant_type> constants(11);
1166             constants[0].i = bottom_blob.dims;
1167             constants[1].i = bottom_blob.w;
1168             constants[2].i = bottom_blob.h;
1169             constants[3].i = bottom_blob.c;
1170             constants[4].i = 0; //bottom_blob.cstep;
1171             constants[5].i = top_blob.dims;
1172             constants[6].i = top_blob.w;
1173             constants[7].i = top_blob.h;
1174             constants[8].i = top_blob.c;
1175             constants[9].i = 0; //top_blob.cstep;
1176             constants[10].i = woffset;
1177 
1178             const Pipeline* pipeline = elempack == 8 ? pipeline_concat_pack8[b % 2]
1179                                        : elempack == 4 ? pipeline_concat_pack4[b % 2]
1180                                        : pipeline_concat[b % 2];
1181 
1182             cmd.record_pipeline(pipeline, bindings, constants, bottom_blob);
1183 
1184             woffset += bottom_blob.w;
1185         }
1186 
1187         return 0;
1188     }
1189 
1190     return 0;
1191 }
1192 
1193 } // namespace ncnn
1194