1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "slice.h"
16 
17 namespace ncnn {
18 
Slice()19 Slice::Slice()
20 {
21 }
22 
load_param(const ParamDict & pd)23 int Slice::load_param(const ParamDict& pd)
24 {
25     slices = pd.get(0, Mat());
26     axis = pd.get(1, 0);
27 
28     return 0;
29 }
30 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const31 int Slice::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
32 {
33     const Mat& bottom_blob = bottom_blobs[0];
34     int dims = bottom_blob.dims;
35     size_t elemsize = bottom_blob.elemsize;
36     const int* slices_ptr = slices;
37     int positive_axis = axis < 0 ? dims + axis : axis;
38 
39     if (dims == 1) // positive_axis == 0
40     {
41         int w = bottom_blob.w;
42 
43         int q = 0;
44         for (size_t i = 0; i < top_blobs.size(); i++)
45         {
46             int slice = slices_ptr[i];
47             if (slice == -233)
48             {
49                 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
50             }
51 
52             Mat& top_blob = top_blobs[i];
53             top_blob.create(slice, elemsize, opt.blob_allocator);
54             if (top_blob.empty())
55                 return -100;
56 
57             const unsigned char* ptr = (const unsigned char*)bottom_blob + q * elemsize;
58             unsigned char* outptr = top_blob;
59             memcpy(outptr, ptr, slice * elemsize);
60 
61             q += slice;
62         }
63 
64         return 0;
65     }
66 
67     if (dims == 2 && positive_axis == 0)
68     {
69         int w = bottom_blob.w;
70         int h = bottom_blob.h;
71 
72         int q = 0;
73         for (size_t i = 0; i < top_blobs.size(); i++)
74         {
75             int slice = slices_ptr[i];
76             if (slice == -233)
77             {
78                 slice = static_cast<int>((h - q) / (top_blobs.size() - i));
79             }
80 
81             Mat& top_blob = top_blobs[i];
82             top_blob.create(w, slice, elemsize, opt.blob_allocator);
83             if (top_blob.empty())
84                 return -100;
85 
86             int size = w * slice;
87 
88             const unsigned char* ptr = bottom_blob.row<const unsigned char>(q);
89             unsigned char* outptr = top_blob;
90             memcpy(outptr, ptr, size * elemsize);
91 
92             q += slice;
93         }
94 
95         return 0;
96     }
97 
98     if (dims == 2 && positive_axis == 1)
99     {
100         int w = bottom_blob.w;
101         int h = bottom_blob.h;
102 
103         int q = 0;
104         for (size_t i = 0; i < top_blobs.size(); i++)
105         {
106             int slice = slices_ptr[i];
107             if (slice == -233)
108             {
109                 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
110             }
111 
112             Mat& top_blob = top_blobs[i];
113             top_blob.create(slice, h, elemsize, opt.blob_allocator);
114             if (top_blob.empty())
115                 return -100;
116 
117             #pragma omp parallel for num_threads(opt.num_threads)
118             for (int j = 0; j < h; j++)
119             {
120                 unsigned char* outptr = top_blob.row<unsigned char>(j);
121                 const unsigned char* ptr = bottom_blob.row<const unsigned char>(j) + q * elemsize;
122                 memcpy(outptr, ptr, slice * elemsize);
123             }
124 
125             q += slice;
126         }
127 
128         return 0;
129     }
130 
131     if (dims == 3 && positive_axis == 0)
132     {
133         int w = bottom_blob.w;
134         int h = bottom_blob.h;
135         int channels = bottom_blob.c;
136 
137         int q = 0;
138         for (size_t i = 0; i < top_blobs.size(); i++)
139         {
140             int slice = slices_ptr[i];
141             if (slice == -233)
142             {
143                 slice = static_cast<int>((channels - q) / (top_blobs.size() - i));
144             }
145 
146             Mat& top_blob = top_blobs[i];
147             top_blob.create(w, h, slice, elemsize, opt.blob_allocator);
148             if (top_blob.empty())
149                 return -100;
150 
151             int size = static_cast<int>(bottom_blob.cstep * slice);
152 
153             const unsigned char* ptr = bottom_blob.channel(q);
154             unsigned char* outptr = top_blob;
155             memcpy(outptr, ptr, size * elemsize);
156 
157             q += slice;
158         }
159 
160         return 0;
161     }
162 
163     if (dims == 3 && positive_axis == 1)
164     {
165         int w = bottom_blob.w;
166         int h = bottom_blob.h;
167         int channels = bottom_blob.c;
168 
169         int q = 0;
170         for (size_t i = 0; i < top_blobs.size(); i++)
171         {
172             int slice = slices_ptr[i];
173             if (slice == -233)
174             {
175                 slice = static_cast<int>((h - q) / (top_blobs.size() - i));
176             }
177 
178             Mat& top_blob = top_blobs[i];
179             top_blob.create(w, slice, channels, elemsize, opt.blob_allocator);
180             if (top_blob.empty())
181                 return -100;
182 
183             #pragma omp parallel for num_threads(opt.num_threads)
184             for (int p = 0; p < channels; p++)
185             {
186                 int size = w * slice;
187 
188                 unsigned char* outptr = top_blob.channel(p);
189                 const unsigned char* ptr = bottom_blob.channel(p).row<const unsigned char>(q);
190                 memcpy(outptr, ptr, size * elemsize);
191             }
192 
193             q += slice;
194         }
195 
196         return 0;
197     }
198 
199     if (dims == 3 && positive_axis == 2)
200     {
201         int w = bottom_blob.w;
202         int h = bottom_blob.h;
203         int channels = bottom_blob.c;
204 
205         int q = 0;
206         for (size_t i = 0; i < top_blobs.size(); i++)
207         {
208             int slice = slices_ptr[i];
209             if (slice == -233)
210             {
211                 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
212             }
213 
214             Mat& top_blob = top_blobs[i];
215             top_blob.create(slice, h, channels, elemsize, opt.blob_allocator);
216             if (top_blob.empty())
217                 return -100;
218 
219             #pragma omp parallel for num_threads(opt.num_threads)
220             for (int p = 0; p < channels; p++)
221             {
222                 unsigned char* outptr = top_blob.channel(p);
223                 const Mat m = bottom_blob.channel(p);
224 
225                 for (int j = 0; j < h; j++)
226                 {
227                     const unsigned char* ptr = m.row<const unsigned char>(j) + q * elemsize;
228                     memcpy(outptr, ptr, slice * elemsize);
229 
230                     outptr += slice * elemsize;
231                 }
232             }
233 
234             q += slice;
235         }
236 
237         return 0;
238     }
239 
240     return 0;
241 }
242 
243 } // namespace ncnn
244