1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "concat.h"
16
17 namespace ncnn {
18
Concat()19 Concat::Concat()
20 {
21 one_blob_only = false;
22 support_inplace = false;
23 }
24
load_param(const ParamDict & pd)25 int Concat::load_param(const ParamDict& pd)
26 {
27 axis = pd.get(0, 0);
28
29 return 0;
30 }
31
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const32 int Concat::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
33 {
34 int dims = bottom_blobs[0].dims;
35 size_t elemsize = bottom_blobs[0].elemsize;
36 int positive_axis = axis < 0 ? dims + axis : axis;
37
38 if (dims == 1) // positive_axis == 0
39 {
40 // concat vector
41 // total length
42 int top_w = 0;
43 for (size_t b = 0; b < bottom_blobs.size(); b++)
44 {
45 const Mat& bottom_blob = bottom_blobs[b];
46 top_w += bottom_blob.w;
47 }
48
49 Mat& top_blob = top_blobs[0];
50 top_blob.create(top_w, elemsize, opt.blob_allocator);
51 if (top_blob.empty())
52 return -100;
53
54 unsigned char* outptr = top_blob;
55 for (size_t b = 0; b < bottom_blobs.size(); b++)
56 {
57 const Mat& bottom_blob = bottom_blobs[b];
58
59 int w = bottom_blob.w;
60
61 const unsigned char* ptr = bottom_blob;
62 memcpy(outptr, ptr, w * elemsize);
63
64 outptr += w * elemsize;
65 }
66
67 return 0;
68 }
69
70 if (dims == 2 && positive_axis == 0)
71 {
72 // concat image
73 int w = bottom_blobs[0].w;
74
75 // total height
76 int top_h = 0;
77 for (size_t b = 0; b < bottom_blobs.size(); b++)
78 {
79 const Mat& bottom_blob = bottom_blobs[b];
80 top_h += bottom_blob.h;
81 }
82
83 Mat& top_blob = top_blobs[0];
84 top_blob.create(w, top_h, elemsize, opt.blob_allocator);
85 if (top_blob.empty())
86 return -100;
87
88 unsigned char* outptr = top_blob;
89 for (size_t b = 0; b < bottom_blobs.size(); b++)
90 {
91 const Mat& bottom_blob = bottom_blobs[b];
92
93 int size = w * bottom_blob.h;
94
95 const unsigned char* ptr = bottom_blob;
96 memcpy(outptr, ptr, size * elemsize);
97
98 outptr += size * elemsize;
99 }
100
101 return 0;
102 }
103
104 if (dims == 2 && positive_axis == 1)
105 {
106 // interleave image row
107 int h = bottom_blobs[0].h;
108
109 // total width
110 int top_w = 0;
111 for (size_t b = 0; b < bottom_blobs.size(); b++)
112 {
113 const Mat& bottom_blob = bottom_blobs[b];
114 top_w += bottom_blob.w;
115 }
116
117 Mat& top_blob = top_blobs[0];
118 top_blob.create(top_w, h, elemsize, opt.blob_allocator);
119 if (top_blob.empty())
120 return -100;
121
122 #pragma omp parallel for num_threads(opt.num_threads)
123 for (int i = 0; i < h; i++)
124 {
125 unsigned char* outptr = top_blob.row<unsigned char>(i);
126 for (size_t b = 0; b < bottom_blobs.size(); b++)
127 {
128 const Mat& bottom_blob = bottom_blobs[b];
129
130 const unsigned char* ptr = bottom_blob.row<const unsigned char>(i);
131 memcpy(outptr, ptr, bottom_blob.w * elemsize);
132
133 outptr += bottom_blob.w * elemsize;
134 }
135 }
136
137 return 0;
138 }
139
140 if (dims == 3 && positive_axis == 0)
141 {
142 // concat dim
143 int w = bottom_blobs[0].w;
144 int h = bottom_blobs[0].h;
145
146 // total channels
147 int top_channels = 0;
148 for (size_t b = 0; b < bottom_blobs.size(); b++)
149 {
150 const Mat& bottom_blob = bottom_blobs[b];
151 top_channels += bottom_blob.c;
152 }
153
154 Mat& top_blob = top_blobs[0];
155 top_blob.create(w, h, top_channels, elemsize, opt.blob_allocator);
156 if (top_blob.empty())
157 return -100;
158
159 int q = 0;
160 for (size_t b = 0; b < bottom_blobs.size(); b++)
161 {
162 const Mat& bottom_blob = bottom_blobs[b];
163
164 int channels = bottom_blob.c;
165 size_t size = bottom_blob.cstep * channels;
166
167 const unsigned char* ptr = bottom_blob;
168 unsigned char* outptr = top_blob.channel(q);
169 memcpy(outptr, ptr, size * elemsize);
170
171 q += channels;
172 }
173
174 return 0;
175 }
176
177 if (dims == 3 && positive_axis == 1)
178 {
179 // interleave dim height
180 int w = bottom_blobs[0].w;
181 int channels = bottom_blobs[0].c;
182
183 // total height
184 int top_h = 0;
185 for (size_t b = 0; b < bottom_blobs.size(); b++)
186 {
187 const Mat& bottom_blob = bottom_blobs[b];
188 top_h += bottom_blob.h;
189 }
190
191 Mat& top_blob = top_blobs[0];
192 top_blob.create(w, top_h, channels, elemsize, opt.blob_allocator);
193 if (top_blob.empty())
194 return -100;
195
196 #pragma omp parallel for num_threads(opt.num_threads)
197 for (int q = 0; q < channels; q++)
198 {
199 unsigned char* outptr = top_blob.channel(q);
200
201 for (size_t b = 0; b < bottom_blobs.size(); b++)
202 {
203 const Mat& bottom_blob = bottom_blobs[b];
204
205 int size = bottom_blob.w * bottom_blob.h;
206
207 const unsigned char* ptr = bottom_blob.channel(q);
208 memcpy(outptr, ptr, size * elemsize);
209
210 outptr += size * elemsize;
211 }
212 }
213
214 return 0;
215 }
216
217 if (dims == 3 && positive_axis == 2)
218 {
219 // interleave dim width
220 int h = bottom_blobs[0].h;
221 int channels = bottom_blobs[0].c;
222
223 // total width
224 int top_w = 0;
225 for (size_t b = 0; b < bottom_blobs.size(); b++)
226 {
227 const Mat& bottom_blob = bottom_blobs[b];
228 top_w += bottom_blob.w;
229 }
230
231 Mat& top_blob = top_blobs[0];
232 top_blob.create(top_w, h, channels, elemsize, opt.blob_allocator);
233 if (top_blob.empty())
234 return -100;
235
236 #pragma omp parallel for num_threads(opt.num_threads)
237 for (int q = 0; q < channels; q++)
238 {
239 unsigned char* outptr = top_blob.channel(q);
240
241 for (int i = 0; i < h; i++)
242 {
243 for (size_t b = 0; b < bottom_blobs.size(); b++)
244 {
245 const Mat& bottom_blob = bottom_blobs[b];
246
247 const unsigned char* ptr = bottom_blob.channel(q).row<const unsigned char>(i);
248 memcpy(outptr, ptr, bottom_blob.w * elemsize);
249
250 outptr += bottom_blob.w * elemsize;
251 }
252 }
253 }
254
255 return 0;
256 }
257
258 return 0;
259 }
260
261 } // namespace ncnn
262