1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup glue_join
20 //! @{
21
22
23
24 template<typename T1, typename T2>
25 inline
26 void
apply_noalias(Mat<typename T1::elem_type> & out,const Proxy<T1> & A,const Proxy<T2> & B)27 glue_join_cols::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& A, const Proxy<T2>& B)
28 {
29 arma_extra_debug_sigprint();
30
31 const uword A_n_rows = A.get_n_rows();
32 const uword A_n_cols = A.get_n_cols();
33
34 const uword B_n_rows = B.get_n_rows();
35 const uword B_n_cols = B.get_n_cols();
36
37 arma_debug_check
38 (
39 ( (A_n_cols != B_n_cols) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ),
40 "join_cols() / join_vert(): number of columns must be the same"
41 );
42
43 out.set_size( A_n_rows + B_n_rows, (std::max)(A_n_cols, B_n_cols) );
44
45 if( out.n_elem > 0 )
46 {
47 if(A.get_n_elem() > 0)
48 {
49 out.submat(0, 0, A_n_rows-1, out.n_cols-1) = A.Q;
50 }
51
52 if(B.get_n_elem() > 0)
53 {
54 out.submat(A_n_rows, 0, out.n_rows-1, out.n_cols-1) = B.Q;
55 }
56 }
57 }
58
59
60
61
62 template<typename T1, typename T2>
63 inline
64 void
apply(Mat<typename T1::elem_type> & out,const Glue<T1,T2,glue_join_cols> & X)65 glue_join_cols::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_join_cols>& X)
66 {
67 arma_extra_debug_sigprint();
68
69 typedef typename T1::elem_type eT;
70
71 const Proxy<T1> A(X.A);
72 const Proxy<T2> B(X.B);
73
74 if( (A.is_alias(out) == false) && (B.is_alias(out) == false) )
75 {
76 glue_join_cols::apply_noalias(out, A, B);
77 }
78 else
79 {
80 Mat<eT> tmp;
81
82 glue_join_cols::apply_noalias(tmp, A, B);
83
84 out.steal_mem(tmp);
85 }
86 }
87
88
89
90 template<typename eT, typename T1, typename T2, typename T3>
91 inline
92 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr)93 glue_join_cols::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr)
94 {
95 arma_extra_debug_sigprint();
96
97 const quasi_unwrap<T1> UA(A_expr.get_ref());
98 const quasi_unwrap<T2> UB(B_expr.get_ref());
99 const quasi_unwrap<T3> UC(C_expr.get_ref());
100
101 const Mat<eT>& A = UA.M;
102 const Mat<eT>& B = UB.M;
103 const Mat<eT>& C = UC.M;
104
105 const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows;
106 const uword out_n_cols = (std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols);
107
108 arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
109 arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
110 arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
111
112 out.set_size(out_n_rows, out_n_cols);
113
114 if(out.n_elem == 0) { return; }
115
116 uword row_start = 0;
117 uword row_end_p1 = 0;
118
119 if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; }
120
121 row_start = row_end_p1;
122
123 if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; }
124
125 row_start = row_end_p1;
126
127 if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; }
128 }
129
130
131
132 template<typename eT, typename T1, typename T2, typename T3, typename T4>
133 inline
134 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr,const Base<eT,T4> & D_expr)135 glue_join_cols::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr, const Base<eT,T4>& D_expr)
136 {
137 arma_extra_debug_sigprint();
138
139 const quasi_unwrap<T1> UA(A_expr.get_ref());
140 const quasi_unwrap<T2> UB(B_expr.get_ref());
141 const quasi_unwrap<T3> UC(C_expr.get_ref());
142 const quasi_unwrap<T4> UD(D_expr.get_ref());
143
144 const Mat<eT>& A = UA.M;
145 const Mat<eT>& B = UB.M;
146 const Mat<eT>& C = UC.M;
147 const Mat<eT>& D = UD.M;
148
149 const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows + D.n_rows;
150 const uword out_n_cols = (std::max)(((std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols)), D.n_cols);
151
152 arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
153 arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
154 arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
155 arma_debug_check( ((D.n_cols != out_n_cols) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
156
157 out.set_size(out_n_rows, out_n_cols);
158
159 if(out.n_elem == 0) { return; }
160
161 uword row_start = 0;
162 uword row_end_p1 = 0;
163
164 if(A.n_elem > 0) { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; }
165
166 row_start = row_end_p1;
167
168 if(B.n_elem > 0) { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; }
169
170 row_start = row_end_p1;
171
172 if(C.n_elem > 0) { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; }
173
174 row_start = row_end_p1;
175
176 if(D.n_elem > 0) { row_end_p1 += D.n_rows; out.rows(row_start, row_end_p1 - 1) = D; }
177 }
178
179
180
181 template<typename T1, typename T2>
182 inline
183 void
apply_noalias(Mat<typename T1::elem_type> & out,const Proxy<T1> & A,const Proxy<T2> & B)184 glue_join_rows::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& A, const Proxy<T2>& B)
185 {
186 arma_extra_debug_sigprint();
187
188 const uword A_n_rows = A.get_n_rows();
189 const uword A_n_cols = A.get_n_cols();
190
191 const uword B_n_rows = B.get_n_rows();
192 const uword B_n_cols = B.get_n_cols();
193
194 arma_debug_check
195 (
196 ( (A_n_rows != B_n_rows) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ),
197 "join_rows() / join_horiz(): number of rows must be the same"
198 );
199
200 out.set_size( (std::max)(A_n_rows, B_n_rows), A_n_cols + B_n_cols );
201
202 if( out.n_elem > 0 )
203 {
204 if(A.get_n_elem() > 0)
205 {
206 out.submat(0, 0, out.n_rows-1, A_n_cols-1) = A.Q;
207 }
208
209 if(B.get_n_elem() > 0)
210 {
211 out.submat(0, A_n_cols, out.n_rows-1, out.n_cols-1) = B.Q;
212 }
213 }
214 }
215
216
217
218
219 template<typename T1, typename T2>
220 inline
221 void
apply(Mat<typename T1::elem_type> & out,const Glue<T1,T2,glue_join_rows> & X)222 glue_join_rows::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_join_rows>& X)
223 {
224 arma_extra_debug_sigprint();
225
226 typedef typename T1::elem_type eT;
227
228 const Proxy<T1> A(X.A);
229 const Proxy<T2> B(X.B);
230
231 if( (A.is_alias(out) == false) && (B.is_alias(out) == false) )
232 {
233 glue_join_rows::apply_noalias(out, A, B);
234 }
235 else
236 {
237 Mat<eT> tmp;
238
239 glue_join_rows::apply_noalias(tmp, A, B);
240
241 out.steal_mem(tmp);
242 }
243 }
244
245
246
247 template<typename eT, typename T1, typename T2, typename T3>
248 inline
249 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr)250 glue_join_rows::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr)
251 {
252 arma_extra_debug_sigprint();
253
254 const quasi_unwrap<T1> UA(A_expr.get_ref());
255 const quasi_unwrap<T2> UB(B_expr.get_ref());
256 const quasi_unwrap<T3> UC(C_expr.get_ref());
257
258 const Mat<eT>& A = UA.M;
259 const Mat<eT>& B = UB.M;
260 const Mat<eT>& C = UC.M;
261
262 const uword out_n_rows = (std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows);
263 const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols;
264
265 arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
266 arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
267 arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
268
269 out.set_size(out_n_rows, out_n_cols);
270
271 if(out.n_elem == 0) { return; }
272
273 uword col_start = 0;
274 uword col_end_p1 = 0;
275
276 if(A.n_elem > 0) { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; }
277
278 col_start = col_end_p1;
279
280 if(B.n_elem > 0) { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; }
281
282 col_start = col_end_p1;
283
284 if(C.n_elem > 0) { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; }
285 }
286
287
288
289 template<typename eT, typename T1, typename T2, typename T3, typename T4>
290 inline
291 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr,const Base<eT,T4> & D_expr)292 glue_join_rows::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr, const Base<eT,T4>& D_expr)
293 {
294 arma_extra_debug_sigprint();
295
296 const quasi_unwrap<T1> UA(A_expr.get_ref());
297 const quasi_unwrap<T2> UB(B_expr.get_ref());
298 const quasi_unwrap<T3> UC(C_expr.get_ref());
299 const quasi_unwrap<T4> UD(D_expr.get_ref());
300
301 const Mat<eT>& A = UA.M;
302 const Mat<eT>& B = UB.M;
303 const Mat<eT>& C = UC.M;
304 const Mat<eT>& D = UD.M;
305
306 const uword out_n_rows = (std::max)(((std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows)), D.n_rows);
307 const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols + D.n_cols;
308
309 arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
310 arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
311 arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
312 arma_debug_check( ((D.n_rows != out_n_rows) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
313
314 out.set_size(out_n_rows, out_n_cols);
315
316 if(out.n_elem == 0) { return; }
317
318 uword col_start = 0;
319 uword col_end_p1 = 0;
320
321 if(A.n_elem > 0) { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; }
322
323 col_start = col_end_p1;
324
325 if(B.n_elem > 0) { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; }
326
327 col_start = col_end_p1;
328
329 if(C.n_elem > 0) { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; }
330
331 col_start = col_end_p1;
332
333 if(D.n_elem > 0) { col_end_p1 += D.n_cols; out.cols(col_start, col_end_p1 - 1) = D; }
334 }
335
336
337
338 template<typename T1, typename T2>
339 inline
340 void
apply(Cube<typename T1::elem_type> & out,const GlueCube<T1,T2,glue_join_slices> & X)341 glue_join_slices::apply(Cube<typename T1::elem_type>& out, const GlueCube<T1,T2,glue_join_slices>& X)
342 {
343 arma_extra_debug_sigprint();
344
345 typedef typename T1::elem_type eT;
346
347 const unwrap_cube<T1> A_tmp(X.A);
348 const unwrap_cube<T2> B_tmp(X.B);
349
350 const Cube<eT>& A = A_tmp.M;
351 const Cube<eT>& B = B_tmp.M;
352
353 if(A.n_elem == 0) { out = B; return; }
354 if(B.n_elem == 0) { out = A; return; }
355
356 arma_debug_check( ( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) ), "join_slices(): size of slices must be the same" );
357
358 if( (&out != &A) && (&out != &B) )
359 {
360 out.set_size(A.n_rows, A.n_cols, A.n_slices + B.n_slices);
361
362 out.slices(0, A.n_slices-1 ) = A;
363 out.slices(A.n_slices, out.n_slices-1) = B;
364 }
365 else // we have aliasing
366 {
367 Cube<eT> C(A.n_rows, A.n_cols, A.n_slices + B.n_slices, arma_nozeros_indicator());
368
369 C.slices(0, A.n_slices-1) = A;
370 C.slices(A.n_slices, C.n_slices-1) = B;
371
372 out.steal_mem(C);
373 }
374
375 }
376
377
378
379 //! @}
380