// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
14 
15 #include "concat_x86.h"
16 
17 namespace ncnn {
18 
Concat_x86()19 Concat_x86::Concat_x86()
20 {
21 #if __SSE2__
22     support_packing = true;
23 #endif // __SSE2__
24 }
25 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const26 int Concat_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
27 {
28     int dims = bottom_blobs[0].dims;
29     int positive_axis = axis < 0 ? dims + axis : axis;
30 
31     if (dims == 1) // positive_axis == 0
32     {
33         // concat vector
34         // total length
35         size_t elemsize = bottom_blobs[0].elemsize;
36         int elempack = bottom_blobs[0].elempack;
37         int top_w = 0;
38         for (size_t b = 0; b < bottom_blobs.size(); b++)
39         {
40             const Mat& bottom_blob = bottom_blobs[b];
41             top_w += bottom_blob.w * bottom_blob.elempack;
42         }
43 
44         int out_elempack = 1;
45 #if __SSE2__
46         if (opt.use_packing_layout)
47         {
48 #if __AVX__
49             out_elempack = top_w % 8 == 0 ? 8 : top_w % 4 == 0 ? 4 : 1;
50 #else
51             out_elempack = top_w % 4 == 0 ? 4 : 1;
52 #endif
53         }
54 #endif // __SSE2__
55         size_t out_elemsize = elemsize / elempack * out_elempack;
56 
57         Mat& top_blob = top_blobs[0];
58         top_blob.create(top_w / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
59         if (top_blob.empty())
60             return -100;
61 
62         float* outptr = top_blob;
63         for (size_t b = 0; b < bottom_blobs.size(); b++)
64         {
65             const Mat& bottom_blob = bottom_blobs[b];
66 
67             const float* ptr = bottom_blob;
68             memcpy(outptr, ptr, bottom_blob.w * bottom_blob.elemsize);
69 
70             outptr += bottom_blob.w * bottom_blob.elempack;
71         }
72     }
73 
74     if (dims == 2 && positive_axis == 0)
75     {
76         // concat image
77         int w = bottom_blobs[0].w;
78 
79         // total height
80         size_t elemsize = bottom_blobs[0].elemsize;
81         int elempack = bottom_blobs[0].elempack;
82         int top_h = 0;
83         for (size_t b = 0; b < bottom_blobs.size(); b++)
84         {
85             const Mat& bottom_blob = bottom_blobs[b];
86             elemsize = std::min(elemsize, bottom_blob.elemsize);
87             elempack = std::min(elempack, bottom_blob.elempack);
88             top_h += bottom_blob.h * bottom_blob.elempack;
89         }
90 
91         int out_elempack = 1;
92 #if __SSE2__
93         if (opt.use_packing_layout)
94         {
95 #if __AVX__
96             out_elempack = top_h % 8 == 0 ? 8 : top_h % 4 == 0 ? 4 : 1;
97 #else
98             out_elempack = top_h % 4 == 0 ? 4 : 1;
99 #endif
100         }
101 #endif // __SSE2__
102         size_t out_elemsize = elemsize / elempack * out_elempack;
103 
104         Mat& top_blob = top_blobs[0];
105         top_blob.create(w, top_h / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
106         if (top_blob.empty())
107             return -100;
108 
109         Mat top_blob_unpacked = top_blob;
110         if (elempack < out_elempack)
111         {
112             top_blob_unpacked.create(w, top_h / elempack, elemsize, elempack, opt.workspace_allocator);
113             if (top_blob_unpacked.empty())
114                 return -100;
115         }
116 
117         float* outptr = top_blob_unpacked;
118         for (size_t b = 0; b < bottom_blobs.size(); b++)
119         {
120             const Mat& bottom_blob = bottom_blobs[b];
121 
122 #if __AVX__
123             if (bottom_blob.elempack == 8 && elempack == 4)
124             {
125                 for (int i = 0; i < bottom_blob.h; i++)
126                 {
127                     const float* r0 = bottom_blob.row(i);
128 
129                     float* outptr0 = outptr;
130                     float* outptr1 = outptr + w * 4;
131 
132                     for (int j = 0; j < w; j++)
133                     {
134                         outptr0[0] = r0[0];
135                         outptr0[1] = r0[1];
136                         outptr0[2] = r0[2];
137                         outptr0[3] = r0[3];
138                         outptr1[0] = r0[4];
139                         outptr1[1] = r0[5];
140                         outptr1[2] = r0[6];
141                         outptr1[3] = r0[7];
142 
143                         outptr0 += 4;
144                         outptr1 += 4;
145                         r0 += 8;
146                     }
147 
148                     outptr += w * 8;
149                 }
150             }
151             if (bottom_blob.elempack == 8 && elempack == 1)
152             {
153                 for (int i = 0; i < bottom_blob.h; i++)
154                 {
155                     const float* r0 = bottom_blob.row(i);
156 
157                     float* outptr0 = outptr;
158                     float* outptr1 = outptr + w;
159                     float* outptr2 = outptr + w * 2;
160                     float* outptr3 = outptr + w * 3;
161                     float* outptr4 = outptr + w * 4;
162                     float* outptr5 = outptr + w * 5;
163                     float* outptr6 = outptr + w * 6;
164                     float* outptr7 = outptr + w * 7;
165 
166                     for (int j = 0; j < w; j++)
167                     {
168                         *outptr0++ = r0[0];
169                         *outptr1++ = r0[1];
170                         *outptr2++ = r0[2];
171                         *outptr3++ = r0[3];
172                         *outptr4++ = r0[4];
173                         *outptr5++ = r0[5];
174                         *outptr6++ = r0[6];
175                         *outptr7++ = r0[7];
176 
177                         r0 += 8;
178                     }
179 
180                     outptr += w * 8;
181                 }
182             }
183 #endif // __AVX__
184             if (bottom_blob.elempack == 4 && elempack == 1)
185             {
186                 for (int i = 0; i < bottom_blob.h; i++)
187                 {
188                     const float* r0 = bottom_blob.row(i);
189 
190                     float* outptr0 = outptr;
191                     float* outptr1 = outptr + w;
192                     float* outptr2 = outptr + w * 2;
193                     float* outptr3 = outptr + w * 3;
194 
195                     for (int j = 0; j < w; j++)
196                     {
197                         *outptr0++ = r0[0];
198                         *outptr1++ = r0[1];
199                         *outptr2++ = r0[2];
200                         *outptr3++ = r0[3];
201 
202                         r0 += 4;
203                     }
204 
205                     outptr += w * 4;
206                 }
207             }
208             if (bottom_blob.elempack == elempack) // 1-1 4-4 8-8
209             {
210                 int size = w * bottom_blob.h;
211 
212                 const float* ptr = bottom_blob;
213                 memcpy(outptr, ptr, size * bottom_blob.elemsize);
214 
215                 outptr += size * bottom_blob.elempack;
216             }
217         }
218 
219         // packing
220         if (elempack < out_elempack)
221         {
222             convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
223         }
224     }
225 
226     if (dims == 2 && positive_axis == 1)
227     {
228         // interleave image row
229         int h = bottom_blobs[0].h;
230         size_t elemsize = bottom_blobs[0].elemsize;
231         int elempack = bottom_blobs[0].elempack;
232 
233         // total width
234         int top_w = 0;
235         for (size_t b = 0; b < bottom_blobs.size(); b++)
236         {
237             const Mat& bottom_blob = bottom_blobs[b];
238             top_w += bottom_blob.w;
239         }
240 
241         Mat& top_blob = top_blobs[0];
242         top_blob.create(top_w, h, elemsize, elempack, opt.blob_allocator);
243         if (top_blob.empty())
244             return -100;
245 
246         #pragma omp parallel for num_threads(opt.num_threads)
247         for (int i = 0; i < h; i++)
248         {
249             float* outptr = top_blob.row(i);
250             for (size_t b = 0; b < bottom_blobs.size(); b++)
251             {
252                 const Mat& bottom_blob = bottom_blobs[b];
253 
254                 const float* ptr = bottom_blob.row(i);
255                 memcpy(outptr, ptr, bottom_blob.w * elemsize);
256 
257                 outptr += bottom_blob.w * elempack;
258             }
259         }
260     }
261 
262     if (dims == 3 && positive_axis == 0)
263     {
264         // concat dim
265         int w = bottom_blobs[0].w;
266         int h = bottom_blobs[0].h;
267 
268         // total channels
269         size_t elemsize = bottom_blobs[0].elemsize;
270         int elempack = bottom_blobs[0].elempack;
271         int top_channels = 0;
272         for (size_t b = 0; b < bottom_blobs.size(); b++)
273         {
274             const Mat& bottom_blob = bottom_blobs[b];
275             elemsize = std::min(elemsize, bottom_blob.elemsize);
276             elempack = std::min(elempack, bottom_blob.elempack);
277             top_channels += bottom_blob.c * bottom_blob.elempack;
278         }
279 
280         int out_elempack = 1;
281 #if __SSE2__
282         if (opt.use_packing_layout)
283         {
284 #if __AVX__
285             out_elempack = top_channels % 8 == 0 ? 8 : top_channels % 4 == 0 ? 4 : 1;
286 #else
287             out_elempack = top_channels % 4 == 0 ? 4 : 1;
288 #endif
289         }
290 #endif // __SSE2__
291         size_t out_elemsize = elemsize / elempack * out_elempack;
292 
293         Mat& top_blob = top_blobs[0];
294         top_blob.create(w, h, top_channels / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
295         if (top_blob.empty())
296             return -100;
297 
298         Mat top_blob_unpacked = top_blob;
299         if (elempack < out_elempack)
300         {
301             top_blob_unpacked.create(w, h, top_channels / elempack, elemsize, elempack, opt.workspace_allocator);
302             if (top_blob_unpacked.empty())
303                 return -100;
304         }
305 
306         int p = 0;
307         for (size_t b = 0; b < bottom_blobs.size(); b++)
308         {
309             const Mat& bottom_blob = bottom_blobs[b];
310 
311 #if __AVX__
312             if (bottom_blob.elempack == 8 && elempack == 4)
313             {
314                 int size = bottom_blob.w * bottom_blob.h;
315 
316                 for (int q = 0; q < bottom_blob.c; q++)
317                 {
318                     const float* r0 = bottom_blob.channel(q);
319 
320                     float* outptr0 = top_blob_unpacked.channel(p);
321                     float* outptr1 = top_blob_unpacked.channel(p + 1);
322 
323                     for (int i = 0; i < size; i++)
324                     {
325                         outptr0[0] = r0[0];
326                         outptr0[1] = r0[1];
327                         outptr0[2] = r0[2];
328                         outptr0[3] = r0[3];
329                         outptr1[0] = r0[4];
330                         outptr1[1] = r0[5];
331                         outptr1[2] = r0[6];
332                         outptr1[3] = r0[7];
333 
334                         outptr0 += 4;
335                         outptr1 += 4;
336                         r0 += 8;
337                     }
338 
339                     p += 2;
340                 }
341             }
342             if (bottom_blob.elempack == 8 && elempack == 1)
343             {
344                 int size = bottom_blob.w * bottom_blob.h;
345 
346                 for (int q = 0; q < bottom_blob.c; q++)
347                 {
348                     const float* r0 = bottom_blob.channel(q);
349 
350                     float* outptr0 = top_blob_unpacked.channel(p);
351                     float* outptr1 = top_blob_unpacked.channel(p + 1);
352                     float* outptr2 = top_blob_unpacked.channel(p + 2);
353                     float* outptr3 = top_blob_unpacked.channel(p + 3);
354                     float* outptr4 = top_blob_unpacked.channel(p + 4);
355                     float* outptr5 = top_blob_unpacked.channel(p + 5);
356                     float* outptr6 = top_blob_unpacked.channel(p + 6);
357                     float* outptr7 = top_blob_unpacked.channel(p + 7);
358 
359                     for (int i = 0; i < size; i++)
360                     {
361                         *outptr0++ = r0[0];
362                         *outptr1++ = r0[1];
363                         *outptr2++ = r0[2];
364                         *outptr3++ = r0[3];
365                         *outptr4++ = r0[4];
366                         *outptr5++ = r0[5];
367                         *outptr6++ = r0[6];
368                         *outptr7++ = r0[7];
369 
370                         r0 += 8;
371                     }
372 
373                     p += 8;
374                 }
375             }
376 #endif // __AVX__
377             if (bottom_blob.elempack == 4 && elempack == 1)
378             {
379                 int size = bottom_blob.w * bottom_blob.h;
380 
381                 for (int q = 0; q < bottom_blob.c; q++)
382                 {
383                     const float* r0 = bottom_blob.channel(q);
384 
385                     float* outptr0 = top_blob_unpacked.channel(p);
386                     float* outptr1 = top_blob_unpacked.channel(p + 1);
387                     float* outptr2 = top_blob_unpacked.channel(p + 2);
388                     float* outptr3 = top_blob_unpacked.channel(p + 3);
389 
390                     for (int i = 0; i < size; i++)
391                     {
392                         *outptr0++ = r0[0];
393                         *outptr1++ = r0[1];
394                         *outptr2++ = r0[2];
395                         *outptr3++ = r0[3];
396 
397                         r0 += 4;
398                     }
399 
400                     p += 4;
401                 }
402             }
403             if (bottom_blob.elempack == elempack) // 1-1 4-4 8-8
404             {
405                 int size = bottom_blob.total();
406 
407                 const float* ptr = bottom_blob;
408                 float* outptr = top_blob_unpacked.channel(p);
409                 memcpy(outptr, ptr, size * bottom_blob.elemsize);
410 
411                 p += bottom_blob.c;
412             }
413         }
414 
415         // packing
416         if (elempack < out_elempack)
417         {
418             convert_packing(top_blob_unpacked, top_blob, out_elempack, opt);
419         }
420     }
421 
422     if (dims == 3 && positive_axis == 1)
423     {
424         // interleave dim height
425         int w = bottom_blobs[0].w;
426         int channels = bottom_blobs[0].c;
427         size_t elemsize = bottom_blobs[0].elemsize;
428         int elempack = bottom_blobs[0].elempack;
429 
430         // total height
431         int top_h = 0;
432         for (size_t b = 0; b < bottom_blobs.size(); b++)
433         {
434             const Mat& bottom_blob = bottom_blobs[b];
435             top_h += bottom_blob.h;
436         }
437 
438         Mat& top_blob = top_blobs[0];
439         top_blob.create(w, top_h, channels, elemsize, elempack, opt.blob_allocator);
440         if (top_blob.empty())
441             return -100;
442 
443         #pragma omp parallel for num_threads(opt.num_threads)
444         for (int q = 0; q < channels; q++)
445         {
446             float* outptr = top_blob.channel(q);
447 
448             for (size_t b = 0; b < bottom_blobs.size(); b++)
449             {
450                 const Mat& bottom_blob = bottom_blobs[b];
451 
452                 int size = bottom_blob.w * bottom_blob.h;
453 
454                 const float* ptr = bottom_blob.channel(q);
455                 memcpy(outptr, ptr, size * elemsize);
456 
457                 outptr += size * elempack;
458             }
459         }
460     }
461 
462     if (dims == 3 && positive_axis == 2)
463     {
464         // interleave dim width
465         int h = bottom_blobs[0].h;
466         int channels = bottom_blobs[0].c;
467         size_t elemsize = bottom_blobs[0].elemsize;
468         int elempack = bottom_blobs[0].elempack;
469 
470         // total height
471         int top_w = 0;
472         for (size_t b = 0; b < bottom_blobs.size(); b++)
473         {
474             const Mat& bottom_blob = bottom_blobs[b];
475             top_w += bottom_blob.w;
476         }
477 
478         Mat& top_blob = top_blobs[0];
479         top_blob.create(top_w, h, channels, elemsize, elempack, opt.blob_allocator);
480         if (top_blob.empty())
481             return -100;
482 
483         #pragma omp parallel for num_threads(opt.num_threads)
484         for (int q = 0; q < channels; q++)
485         {
486             float* outptr = top_blob.channel(q);
487 
488             for (int i = 0; i < h; i++)
489             {
490                 for (size_t b = 0; b < bottom_blobs.size(); b++)
491                 {
492                     const Mat& bottom_blob = bottom_blobs[b];
493 
494                     const float* ptr = bottom_blob.channel(q).row(i);
495                     memcpy(outptr, ptr, bottom_blob.w * elemsize);
496 
497                     outptr += bottom_blob.w * elempack;
498                 }
499             }
500         }
501     }
502 
503     return 0;
504 }
505 
506 } // namespace ncnn
507