1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "slice.h"
16
17 namespace ncnn {
18
Slice()19 Slice::Slice()
20 {
21 }
22
load_param(const ParamDict & pd)23 int Slice::load_param(const ParamDict& pd)
24 {
25 slices = pd.get(0, Mat());
26 axis = pd.get(1, 0);
27
28 return 0;
29 }
30
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const31 int Slice::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
32 {
33 const Mat& bottom_blob = bottom_blobs[0];
34 int dims = bottom_blob.dims;
35 size_t elemsize = bottom_blob.elemsize;
36 const int* slices_ptr = slices;
37 int positive_axis = axis < 0 ? dims + axis : axis;
38
39 if (dims == 1) // positive_axis == 0
40 {
41 int w = bottom_blob.w;
42
43 int q = 0;
44 for (size_t i = 0; i < top_blobs.size(); i++)
45 {
46 int slice = slices_ptr[i];
47 if (slice == -233)
48 {
49 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
50 }
51
52 Mat& top_blob = top_blobs[i];
53 top_blob.create(slice, elemsize, opt.blob_allocator);
54 if (top_blob.empty())
55 return -100;
56
57 const unsigned char* ptr = (const unsigned char*)bottom_blob + q * elemsize;
58 unsigned char* outptr = top_blob;
59 memcpy(outptr, ptr, slice * elemsize);
60
61 q += slice;
62 }
63
64 return 0;
65 }
66
67 if (dims == 2 && positive_axis == 0)
68 {
69 int w = bottom_blob.w;
70 int h = bottom_blob.h;
71
72 int q = 0;
73 for (size_t i = 0; i < top_blobs.size(); i++)
74 {
75 int slice = slices_ptr[i];
76 if (slice == -233)
77 {
78 slice = static_cast<int>((h - q) / (top_blobs.size() - i));
79 }
80
81 Mat& top_blob = top_blobs[i];
82 top_blob.create(w, slice, elemsize, opt.blob_allocator);
83 if (top_blob.empty())
84 return -100;
85
86 int size = w * slice;
87
88 const unsigned char* ptr = bottom_blob.row<const unsigned char>(q);
89 unsigned char* outptr = top_blob;
90 memcpy(outptr, ptr, size * elemsize);
91
92 q += slice;
93 }
94
95 return 0;
96 }
97
98 if (dims == 2 && positive_axis == 1)
99 {
100 int w = bottom_blob.w;
101 int h = bottom_blob.h;
102
103 int q = 0;
104 for (size_t i = 0; i < top_blobs.size(); i++)
105 {
106 int slice = slices_ptr[i];
107 if (slice == -233)
108 {
109 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
110 }
111
112 Mat& top_blob = top_blobs[i];
113 top_blob.create(slice, h, elemsize, opt.blob_allocator);
114 if (top_blob.empty())
115 return -100;
116
117 #pragma omp parallel for num_threads(opt.num_threads)
118 for (int j = 0; j < h; j++)
119 {
120 unsigned char* outptr = top_blob.row<unsigned char>(j);
121 const unsigned char* ptr = bottom_blob.row<const unsigned char>(j) + q * elemsize;
122 memcpy(outptr, ptr, slice * elemsize);
123 }
124
125 q += slice;
126 }
127
128 return 0;
129 }
130
131 if (dims == 3 && positive_axis == 0)
132 {
133 int w = bottom_blob.w;
134 int h = bottom_blob.h;
135 int channels = bottom_blob.c;
136
137 int q = 0;
138 for (size_t i = 0; i < top_blobs.size(); i++)
139 {
140 int slice = slices_ptr[i];
141 if (slice == -233)
142 {
143 slice = static_cast<int>((channels - q) / (top_blobs.size() - i));
144 }
145
146 Mat& top_blob = top_blobs[i];
147 top_blob.create(w, h, slice, elemsize, opt.blob_allocator);
148 if (top_blob.empty())
149 return -100;
150
151 int size = static_cast<int>(bottom_blob.cstep * slice);
152
153 const unsigned char* ptr = bottom_blob.channel(q);
154 unsigned char* outptr = top_blob;
155 memcpy(outptr, ptr, size * elemsize);
156
157 q += slice;
158 }
159
160 return 0;
161 }
162
163 if (dims == 3 && positive_axis == 1)
164 {
165 int w = bottom_blob.w;
166 int h = bottom_blob.h;
167 int channels = bottom_blob.c;
168
169 int q = 0;
170 for (size_t i = 0; i < top_blobs.size(); i++)
171 {
172 int slice = slices_ptr[i];
173 if (slice == -233)
174 {
175 slice = static_cast<int>((h - q) / (top_blobs.size() - i));
176 }
177
178 Mat& top_blob = top_blobs[i];
179 top_blob.create(w, slice, channels, elemsize, opt.blob_allocator);
180 if (top_blob.empty())
181 return -100;
182
183 #pragma omp parallel for num_threads(opt.num_threads)
184 for (int p = 0; p < channels; p++)
185 {
186 int size = w * slice;
187
188 unsigned char* outptr = top_blob.channel(p);
189 const unsigned char* ptr = bottom_blob.channel(p).row<const unsigned char>(q);
190 memcpy(outptr, ptr, size * elemsize);
191 }
192
193 q += slice;
194 }
195
196 return 0;
197 }
198
199 if (dims == 3 && positive_axis == 2)
200 {
201 int w = bottom_blob.w;
202 int h = bottom_blob.h;
203 int channels = bottom_blob.c;
204
205 int q = 0;
206 for (size_t i = 0; i < top_blobs.size(); i++)
207 {
208 int slice = slices_ptr[i];
209 if (slice == -233)
210 {
211 slice = static_cast<int>((w - q) / (top_blobs.size() - i));
212 }
213
214 Mat& top_blob = top_blobs[i];
215 top_blob.create(slice, h, channels, elemsize, opt.blob_allocator);
216 if (top_blob.empty())
217 return -100;
218
219 #pragma omp parallel for num_threads(opt.num_threads)
220 for (int p = 0; p < channels; p++)
221 {
222 unsigned char* outptr = top_blob.channel(p);
223 const Mat m = bottom_blob.channel(p);
224
225 for (int j = 0; j < h; j++)
226 {
227 const unsigned char* ptr = m.row<const unsigned char>(j) + q * elemsize;
228 memcpy(outptr, ptr, slice * elemsize);
229
230 outptr += slice * elemsize;
231 }
232 }
233
234 q += slice;
235 }
236
237 return 0;
238 }
239
240 return 0;
241 }
242
243 } // namespace ncnn
244