1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "slice_x86.h"
16 
17 #if __SSE2__
18 #include <emmintrin.h>
19 #if __AVX__
20 #include <immintrin.h>
21 #endif // __AVX__
22 #endif // __SSE2__
23 
24 namespace ncnn {
25 
Slice_x86()26 Slice_x86::Slice_x86()
27 {
28 #if __SSE2__
29     support_packing = true;
30 #endif // __SSE2__
31 }
32 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const33 int Slice_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
34 {
35     const Mat& bottom_blob = bottom_blobs[0];
36     int dims = bottom_blob.dims;
37     size_t elemsize = bottom_blob.elemsize;
38     int elempack = bottom_blob.elempack;
39     const int* slices_ptr = slices;
40     int positive_axis = axis < 0 ? dims + axis : axis;
41 
42     if (dims == 1) // positive_axis == 0
43     {
44         // slice vector
45         int w = bottom_blob.w * elempack;
46         int q = 0;
47         for (size_t i = 0; i < top_blobs.size(); i++)
48         {
49             int slice = slices_ptr[i];
50             if (slice == -233)
51             {
52                 slice = (w - q) / (top_blobs.size() - i);
53             }
54 
55             int out_elempack = 1;
56 #if __SSE2__
57             if (opt.use_packing_layout)
58             {
59 #if __AVX__
60                 out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1;
61 #else
62                 out_elempack = slice % 4 == 0 ? 4 : 1;
63 #endif
64             }
65 #endif // __SSE2__
66             size_t out_elemsize = elemsize / elempack * out_elempack;
67 
68             Mat& top_blob = top_blobs[i];
69             top_blob.create(slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
70             if (top_blob.empty())
71                 return -100;
72 
73             const float* ptr = (const float*)bottom_blob + q;
74             float* outptr = top_blob;
75             memcpy(outptr, ptr, top_blob.w * top_blob.elemsize);
76 
77             q += slice;
78         }
79     }
80 
81     if (dims == 2 && positive_axis == 0)
82     {
83         // slice image height
84         int w = bottom_blob.w;
85         int h = bottom_blob.h * elempack;
86 
87         int q = 0;
88         for (size_t i = 0; i < top_blobs.size(); i++)
89         {
90             int slice = slices_ptr[i];
91             if (slice == -233)
92             {
93                 slice = (h - q) / (top_blobs.size() - i);
94             }
95 
96             int out_elempack = 1;
97 #if __SSE2__
98             if (opt.use_packing_layout)
99             {
100 #if __AVX__
101                 out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1;
102 #else
103                 out_elempack = slice % 4 == 0 ? 4 : 1;
104 #endif
105             }
106 #endif // __SSE2__
107             size_t out_elemsize = elemsize / elempack * out_elempack;
108 
109             Mat& top_blob = top_blobs[i];
110             top_blob.create(w, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
111             if (top_blob.empty())
112                 return -100;
113 
114             q += slice;
115         }
116 
117         size_t out_elemsize = top_blobs[0].elemsize;
118         int out_elempack = top_blobs[0].elempack;
119         for (size_t i = 0; i < top_blobs.size(); i++)
120         {
121             out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize);
122             out_elempack = std::min(out_elempack, top_blobs[i].elempack);
123         }
124 
125         Mat bottom_blob_unpacked = bottom_blob;
126         if (elempack > out_elempack)
127         {
128             convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt);
129         }
130 
131         const float* ptr = bottom_blob_unpacked;
132         for (size_t i = 0; i < top_blobs.size(); i++)
133         {
134             Mat& top_blob = top_blobs[i];
135 
136 #if __SSE2__
137 #if __AVX__
138             if (out_elempack == 4 && top_blob.elempack == 8)
139             {
140                 for (int j = 0; j < top_blob.h; j++)
141                 {
142                     const float* r0 = ptr;
143                     const float* r1 = ptr + w * 4;
144 
145                     float* outptr0 = top_blob.row(j);
146 
147                     for (int j = 0; j < w; j++)
148                     {
149                         outptr0[0] = r0[0];
150                         outptr0[1] = r0[1];
151                         outptr0[2] = r0[2];
152                         outptr0[3] = r0[3];
153                         outptr0[4] = r1[0];
154                         outptr0[5] = r1[1];
155                         outptr0[6] = r1[2];
156                         outptr0[7] = r1[3];
157 
158                         r0 += 4;
159                         r1 += 4;
160                         outptr0 += 8;
161                     }
162 
163                     ptr += w * 8;
164                 }
165             }
166             if (out_elempack == 1 && top_blob.elempack == 8)
167             {
168                 for (int j = 0; j < top_blob.h; j++)
169                 {
170                     const float* r0 = ptr;
171                     const float* r1 = ptr + w;
172                     const float* r2 = ptr + w * 2;
173                     const float* r3 = ptr + w * 3;
174                     const float* r4 = ptr + w * 4;
175                     const float* r5 = ptr + w * 5;
176                     const float* r6 = ptr + w * 6;
177                     const float* r7 = ptr + w * 7;
178 
179                     float* outptr0 = top_blob.row(j);
180 
181                     for (int j = 0; j < w; j++)
182                     {
183                         outptr0[0] = *r0++;
184                         outptr0[1] = *r1++;
185                         outptr0[2] = *r2++;
186                         outptr0[3] = *r3++;
187                         outptr0[4] = *r4++;
188                         outptr0[5] = *r5++;
189                         outptr0[6] = *r6++;
190                         outptr0[7] = *r7++;
191 
192                         outptr0 += 8;
193                     }
194 
195                     ptr += w * 8;
196                 }
197             }
198 #endif // __AVX__
199             if (out_elempack == 1 && top_blob.elempack == 4)
200             {
201                 for (int j = 0; j < top_blob.h; j++)
202                 {
203                     const float* r0 = ptr;
204                     const float* r1 = ptr + w;
205                     const float* r2 = ptr + w * 2;
206                     const float* r3 = ptr + w * 3;
207 
208                     float* outptr0 = top_blob.row(j);
209 
210                     for (int j = 0; j < w; j++)
211                     {
212                         outptr0[0] = *r0++;
213                         outptr0[1] = *r1++;
214                         outptr0[2] = *r2++;
215                         outptr0[3] = *r3++;
216 
217                         outptr0 += 4;
218                     }
219 
220                     ptr += w * 4;
221                 }
222             }
223 #endif // __SSE2__
224             if (out_elempack == top_blob.elempack)
225             {
226                 // 1-1 4-4 8-8
227                 int size = w * top_blob.h;
228 
229                 float* outptr = top_blob;
230                 memcpy(outptr, ptr, size * top_blob.elemsize);
231 
232                 ptr += size * top_blob.elempack;
233             }
234         }
235     }
236 
237     if (dims == 2 && positive_axis == 1)
238     {
239         // slice image width
240         int w = bottom_blob.w;
241         int h = bottom_blob.h;
242 
243         int q = 0;
244         for (size_t i = 0; i < top_blobs.size(); i++)
245         {
246             int slice = slices_ptr[i];
247             if (slice == -233)
248             {
249                 slice = (w - q) / (top_blobs.size() - i);
250             }
251 
252             Mat& top_blob = top_blobs[i];
253             top_blob.create(slice, h, elemsize, elempack, opt.blob_allocator);
254             if (top_blob.empty())
255                 return -100;
256 
257             q += slice;
258         }
259 
260         #pragma omp parallel for num_threads(opt.num_threads)
261         for (int j = 0; j < h; j++)
262         {
263             const float* ptr = bottom_blob.row<const float>(j);
264             for (size_t i = 0; i < top_blobs.size(); i++)
265             {
266                 Mat& top_blob = top_blobs[i];
267 
268                 float* outptr = top_blob.row(j);
269                 memcpy(outptr, ptr, top_blob.w * elemsize);
270 
271                 ptr += top_blob.w * elempack;
272             }
273         }
274     }
275 
276     if (dims == 3 && positive_axis == 0)
277     {
278         // slice dim channel
279         int w = bottom_blob.w;
280         int h = bottom_blob.h;
281         int channels = bottom_blob.c * elempack;
282 
283         int q = 0;
284         for (size_t i = 0; i < top_blobs.size(); i++)
285         {
286             int slice = slices_ptr[i];
287             if (slice == -233)
288             {
289                 slice = (channels - q) / (top_blobs.size() - i);
290             }
291 
292             int out_elempack = 1;
293 #if __SSE2__
294             if (opt.use_packing_layout)
295             {
296 #if __AVX__
297                 out_elempack = slice % 8 == 0 ? 8 : slice % 4 == 0 ? 4 : 1;
298 #else
299                 out_elempack = slice % 4 == 0 ? 4 : 1;
300 #endif
301             }
302 #endif // __SSE2__
303             size_t out_elemsize = elemsize / elempack * out_elempack;
304 
305             Mat& top_blob = top_blobs[i];
306             top_blob.create(w, h, slice / out_elempack, out_elemsize, out_elempack, opt.blob_allocator);
307             if (top_blob.empty())
308                 return -100;
309 
310             q += slice;
311         }
312 
313         size_t out_elemsize = top_blobs[0].elemsize;
314         int out_elempack = top_blobs[0].elempack;
315         for (size_t i = 0; i < top_blobs.size(); i++)
316         {
317             out_elemsize = std::min(out_elemsize, top_blobs[i].elemsize);
318             out_elempack = std::min(out_elempack, top_blobs[i].elempack);
319         }
320 
321         Mat bottom_blob_unpacked = bottom_blob;
322         if (elempack > out_elempack)
323         {
324             convert_packing(bottom_blob, bottom_blob_unpacked, out_elempack, opt);
325         }
326 
327         int p = 0;
328         for (size_t i = 0; i < top_blobs.size(); i++)
329         {
330             Mat& top_blob = top_blobs[i];
331 
332 #if __SSE2__
333 #if __AVX__
334             if (out_elempack == 4 && top_blob.elempack == 8)
335             {
336                 int size = top_blob.w * top_blob.h;
337 
338                 for (int q = 0; q < top_blob.c; q++)
339                 {
340                     const float* r0 = bottom_blob_unpacked.channel(p);
341                     const float* r1 = bottom_blob_unpacked.channel(p + 1);
342 
343                     float* outptr0 = top_blob.channel(q);
344 
345                     for (int j = 0; j < size; j++)
346                     {
347                         outptr0[0] = r0[0];
348                         outptr0[1] = r0[1];
349                         outptr0[2] = r0[2];
350                         outptr0[3] = r0[3];
351                         outptr0[4] = r1[0];
352                         outptr0[5] = r1[1];
353                         outptr0[6] = r1[2];
354                         outptr0[7] = r1[3];
355 
356                         r0 += 4;
357                         r1 += 4;
358                         outptr0 += 8;
359                     }
360 
361                     p += 2;
362                 }
363             }
364             if (out_elempack == 1 && top_blob.elempack == 8)
365             {
366                 int size = top_blob.w * top_blob.h;
367 
368                 for (int q = 0; q < top_blob.c; q++)
369                 {
370                     const float* r0 = bottom_blob_unpacked.channel(p);
371                     const float* r1 = bottom_blob_unpacked.channel(p + 1);
372                     const float* r2 = bottom_blob_unpacked.channel(p + 2);
373                     const float* r3 = bottom_blob_unpacked.channel(p + 3);
374                     const float* r4 = bottom_blob_unpacked.channel(p + 4);
375                     const float* r5 = bottom_blob_unpacked.channel(p + 5);
376                     const float* r6 = bottom_blob_unpacked.channel(p + 6);
377                     const float* r7 = bottom_blob_unpacked.channel(p + 7);
378 
379                     float* outptr0 = top_blob.channel(q);
380 
381                     for (int j = 0; j < size; j++)
382                     {
383                         outptr0[0] = *r0++;
384                         outptr0[1] = *r1++;
385                         outptr0[2] = *r2++;
386                         outptr0[3] = *r3++;
387                         outptr0[4] = *r4++;
388                         outptr0[5] = *r5++;
389                         outptr0[6] = *r6++;
390                         outptr0[7] = *r7++;
391 
392                         outptr0 += 8;
393                     }
394 
395                     p += 8;
396                 }
397             }
398 #endif // __AVX__
399             if (out_elempack == 1 && top_blob.elempack == 4)
400             {
401                 int size = top_blob.w * top_blob.h;
402 
403                 for (int q = 0; q < top_blob.c; q++)
404                 {
405                     const float* r0 = bottom_blob_unpacked.channel(p);
406                     const float* r1 = bottom_blob_unpacked.channel(p + 1);
407                     const float* r2 = bottom_blob_unpacked.channel(p + 2);
408                     const float* r3 = bottom_blob_unpacked.channel(p + 3);
409 
410                     float* outptr0 = top_blob.channel(q);
411 
412                     for (int j = 0; j < size; j++)
413                     {
414                         outptr0[0] = *r0++;
415                         outptr0[1] = *r1++;
416                         outptr0[2] = *r2++;
417                         outptr0[3] = *r3++;
418 
419                         outptr0 += 4;
420                     }
421 
422                     p += 4;
423                 }
424             }
425 #endif // __SSE2__
426             if (out_elempack == top_blob.elempack)
427             {
428                 // 1-1 4-4 8-8
429                 int size = top_blob.total();
430 
431                 const float* ptr = bottom_blob_unpacked.channel(p);
432                 float* outptr = top_blob;
433                 memcpy(outptr, ptr, size * top_blob.elemsize);
434 
435                 p += top_blob.c;
436             }
437         }
438     }
439 
440     if (dims == 3 && positive_axis == 1)
441     {
442         // slice dim height
443         int w = bottom_blob.w;
444         int h = bottom_blob.h;
445         int channels = bottom_blob.c;
446 
447         int q = 0;
448         for (size_t i = 0; i < top_blobs.size(); i++)
449         {
450             int slice = slices_ptr[i];
451             if (slice == -233)
452             {
453                 slice = (h - q) / (top_blobs.size() - i);
454             }
455 
456             Mat& top_blob = top_blobs[i];
457             top_blob.create(w, slice, channels, elemsize, elempack, opt.blob_allocator);
458             if (top_blob.empty())
459                 return -100;
460 
461             q += slice;
462         }
463 
464         #pragma omp parallel for num_threads(opt.num_threads)
465         for (int p = 0; p < channels; p++)
466         {
467             const float* ptr = bottom_blob.channel(p);
468 
469             for (size_t i = 0; i < top_blobs.size(); i++)
470             {
471                 Mat& top_blob = top_blobs[i];
472 
473                 int size = top_blob.w * top_blob.h;
474 
475                 float* outptr = top_blob.channel(p);
476                 memcpy(outptr, ptr, size * elemsize);
477 
478                 ptr += size * elempack;
479             }
480         }
481     }
482 
483     if (dims == 3 && positive_axis == 2)
484     {
485         // slice dim width
486         int w = bottom_blob.w;
487         int h = bottom_blob.h;
488         int channels = bottom_blob.c;
489 
490         int q = 0;
491         for (size_t i = 0; i < top_blobs.size(); i++)
492         {
493             int slice = slices_ptr[i];
494             if (slice == -233)
495             {
496                 slice = (w - q) / (top_blobs.size() - i);
497             }
498 
499             Mat& top_blob = top_blobs[i];
500             top_blob.create(slice, h, channels, elemsize, elempack, opt.blob_allocator);
501             if (top_blob.empty())
502                 return -100;
503 
504             q += slice;
505         }
506 
507         #pragma omp parallel for num_threads(opt.num_threads)
508         for (int p = 0; p < channels; p++)
509         {
510             const float* ptr = bottom_blob.channel(p);
511 
512             for (int j = 0; j < h; j++)
513             {
514                 for (size_t i = 0; i < top_blobs.size(); i++)
515                 {
516                     Mat& top_blob = top_blobs[i];
517 
518                     float* outptr = top_blob.channel(p).row(j);
519                     memcpy(outptr, ptr, top_blob.w * elemsize);
520 
521                     ptr += top_blob.w * elempack;
522                 }
523             }
524         }
525     }
526 
527     return 0;
528 }
529 
530 } // namespace ncnn
531