1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup glue_join
20 //! @{
21 
22 
23 
24 template<typename T1, typename T2>
25 inline
26 void
apply_noalias(Mat<typename T1::elem_type> & out,const Proxy<T1> & A,const Proxy<T2> & B)27 glue_join_cols::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& A, const Proxy<T2>& B)
28   {
29   arma_extra_debug_sigprint();
30 
31   const uword A_n_rows = A.get_n_rows();
32   const uword A_n_cols = A.get_n_cols();
33 
34   const uword B_n_rows = B.get_n_rows();
35   const uword B_n_cols = B.get_n_cols();
36 
37   arma_debug_check
38     (
39     ( (A_n_cols != B_n_cols) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ),
40     "join_cols() / join_vert(): number of columns must be the same"
41     );
42 
43   out.set_size( A_n_rows + B_n_rows, (std::max)(A_n_cols, B_n_cols) );
44 
45   if( out.n_elem > 0 )
46     {
47     if(A.get_n_elem() > 0)
48       {
49       out.submat(0,        0,   A_n_rows-1, out.n_cols-1) = A.Q;
50       }
51 
52     if(B.get_n_elem() > 0)
53       {
54       out.submat(A_n_rows, 0, out.n_rows-1, out.n_cols-1) = B.Q;
55       }
56     }
57   }
58 
59 
60 
61 
62 template<typename T1, typename T2>
63 inline
64 void
apply(Mat<typename T1::elem_type> & out,const Glue<T1,T2,glue_join_cols> & X)65 glue_join_cols::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_join_cols>& X)
66   {
67   arma_extra_debug_sigprint();
68 
69   typedef typename T1::elem_type eT;
70 
71   const Proxy<T1> A(X.A);
72   const Proxy<T2> B(X.B);
73 
74   if( (A.is_alias(out) == false) && (B.is_alias(out) == false) )
75     {
76     glue_join_cols::apply_noalias(out, A, B);
77     }
78   else
79     {
80     Mat<eT> tmp;
81 
82     glue_join_cols::apply_noalias(tmp, A, B);
83 
84     out.steal_mem(tmp);
85     }
86   }
87 
88 
89 
90 template<typename eT, typename T1, typename T2, typename T3>
91 inline
92 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr)93 glue_join_cols::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr)
94   {
95   arma_extra_debug_sigprint();
96 
97   const quasi_unwrap<T1> UA(A_expr.get_ref());
98   const quasi_unwrap<T2> UB(B_expr.get_ref());
99   const quasi_unwrap<T3> UC(C_expr.get_ref());
100 
101   const Mat<eT>& A = UA.M;
102   const Mat<eT>& B = UB.M;
103   const Mat<eT>& C = UC.M;
104 
105   const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows;
106   const uword out_n_cols = (std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols);
107 
108   arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
109   arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
110   arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
111 
112   out.set_size(out_n_rows, out_n_cols);
113 
114   if(out.n_elem == 0)  { return; }
115 
116   uword row_start  = 0;
117   uword row_end_p1 = 0;
118 
119   if(A.n_elem > 0)  { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; }
120 
121   row_start = row_end_p1;
122 
123   if(B.n_elem > 0)  { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; }
124 
125   row_start = row_end_p1;
126 
127   if(C.n_elem > 0)  { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; }
128   }
129 
130 
131 
132 template<typename eT, typename T1, typename T2, typename T3, typename T4>
133 inline
134 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr,const Base<eT,T4> & D_expr)135 glue_join_cols::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr, const Base<eT,T4>& D_expr)
136   {
137   arma_extra_debug_sigprint();
138 
139   const quasi_unwrap<T1> UA(A_expr.get_ref());
140   const quasi_unwrap<T2> UB(B_expr.get_ref());
141   const quasi_unwrap<T3> UC(C_expr.get_ref());
142   const quasi_unwrap<T4> UD(D_expr.get_ref());
143 
144   const Mat<eT>& A = UA.M;
145   const Mat<eT>& B = UB.M;
146   const Mat<eT>& C = UC.M;
147   const Mat<eT>& D = UD.M;
148 
149   const uword out_n_rows = A.n_rows + B.n_rows + C.n_rows + D.n_rows;
150   const uword out_n_cols = (std::max)(((std::max)((std::max)(A.n_cols, B.n_cols), C.n_cols)), D.n_cols);
151 
152   arma_debug_check( ((A.n_cols != out_n_cols) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
153   arma_debug_check( ((B.n_cols != out_n_cols) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
154   arma_debug_check( ((C.n_cols != out_n_cols) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
155   arma_debug_check( ((D.n_cols != out_n_cols) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_cols() / join_vert(): number of columns must be the same" );
156 
157   out.set_size(out_n_rows, out_n_cols);
158 
159   if(out.n_elem == 0)  { return; }
160 
161   uword row_start  = 0;
162   uword row_end_p1 = 0;
163 
164   if(A.n_elem > 0)  { row_end_p1 += A.n_rows; out.rows(row_start, row_end_p1 - 1) = A; }
165 
166   row_start = row_end_p1;
167 
168   if(B.n_elem > 0)  { row_end_p1 += B.n_rows; out.rows(row_start, row_end_p1 - 1) = B; }
169 
170   row_start = row_end_p1;
171 
172   if(C.n_elem > 0)  { row_end_p1 += C.n_rows; out.rows(row_start, row_end_p1 - 1) = C; }
173 
174   row_start = row_end_p1;
175 
176   if(D.n_elem > 0)  { row_end_p1 += D.n_rows; out.rows(row_start, row_end_p1 - 1) = D; }
177   }
178 
179 
180 
181 template<typename T1, typename T2>
182 inline
183 void
apply_noalias(Mat<typename T1::elem_type> & out,const Proxy<T1> & A,const Proxy<T2> & B)184 glue_join_rows::apply_noalias(Mat<typename T1::elem_type>& out, const Proxy<T1>& A, const Proxy<T2>& B)
185   {
186   arma_extra_debug_sigprint();
187 
188   const uword A_n_rows = A.get_n_rows();
189   const uword A_n_cols = A.get_n_cols();
190 
191   const uword B_n_rows = B.get_n_rows();
192   const uword B_n_cols = B.get_n_cols();
193 
194   arma_debug_check
195     (
196     ( (A_n_rows != B_n_rows) && ( (A_n_rows > 0) || (A_n_cols > 0) ) && ( (B_n_rows > 0) || (B_n_cols > 0) ) ),
197     "join_rows() / join_horiz(): number of rows must be the same"
198     );
199 
200   out.set_size( (std::max)(A_n_rows, B_n_rows), A_n_cols + B_n_cols );
201 
202   if( out.n_elem > 0 )
203     {
204     if(A.get_n_elem() > 0)
205       {
206       out.submat(0, 0,        out.n_rows-1,   A_n_cols-1) = A.Q;
207       }
208 
209     if(B.get_n_elem() > 0)
210       {
211       out.submat(0, A_n_cols, out.n_rows-1, out.n_cols-1) = B.Q;
212       }
213     }
214   }
215 
216 
217 
218 
219 template<typename T1, typename T2>
220 inline
221 void
apply(Mat<typename T1::elem_type> & out,const Glue<T1,T2,glue_join_rows> & X)222 glue_join_rows::apply(Mat<typename T1::elem_type>& out, const Glue<T1,T2,glue_join_rows>& X)
223   {
224   arma_extra_debug_sigprint();
225 
226   typedef typename T1::elem_type eT;
227 
228   const Proxy<T1> A(X.A);
229   const Proxy<T2> B(X.B);
230 
231   if( (A.is_alias(out) == false) && (B.is_alias(out) == false) )
232     {
233     glue_join_rows::apply_noalias(out, A, B);
234     }
235   else
236     {
237     Mat<eT> tmp;
238 
239     glue_join_rows::apply_noalias(tmp, A, B);
240 
241     out.steal_mem(tmp);
242     }
243   }
244 
245 
246 
247 template<typename eT, typename T1, typename T2, typename T3>
248 inline
249 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr)250 glue_join_rows::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr)
251   {
252   arma_extra_debug_sigprint();
253 
254   const quasi_unwrap<T1> UA(A_expr.get_ref());
255   const quasi_unwrap<T2> UB(B_expr.get_ref());
256   const quasi_unwrap<T3> UC(C_expr.get_ref());
257 
258   const Mat<eT>& A = UA.M;
259   const Mat<eT>& B = UB.M;
260   const Mat<eT>& C = UC.M;
261 
262   const uword out_n_rows = (std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows);
263   const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols;
264 
265   arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
266   arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
267   arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
268 
269   out.set_size(out_n_rows, out_n_cols);
270 
271   if(out.n_elem == 0)  { return; }
272 
273   uword col_start  = 0;
274   uword col_end_p1 = 0;
275 
276   if(A.n_elem > 0)  { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; }
277 
278   col_start = col_end_p1;
279 
280   if(B.n_elem > 0)  { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; }
281 
282   col_start = col_end_p1;
283 
284   if(C.n_elem > 0)  { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; }
285   }
286 
287 
288 
289 template<typename eT, typename T1, typename T2, typename T3, typename T4>
290 inline
291 void
apply(Mat<eT> & out,const Base<eT,T1> & A_expr,const Base<eT,T2> & B_expr,const Base<eT,T3> & C_expr,const Base<eT,T4> & D_expr)292 glue_join_rows::apply(Mat<eT>& out, const Base<eT,T1>& A_expr, const Base<eT,T2>& B_expr, const Base<eT,T3>& C_expr, const Base<eT,T4>& D_expr)
293   {
294   arma_extra_debug_sigprint();
295 
296   const quasi_unwrap<T1> UA(A_expr.get_ref());
297   const quasi_unwrap<T2> UB(B_expr.get_ref());
298   const quasi_unwrap<T3> UC(C_expr.get_ref());
299   const quasi_unwrap<T4> UD(D_expr.get_ref());
300 
301   const Mat<eT>& A = UA.M;
302   const Mat<eT>& B = UB.M;
303   const Mat<eT>& C = UC.M;
304   const Mat<eT>& D = UD.M;
305 
306   const uword out_n_rows = (std::max)(((std::max)((std::max)(A.n_rows, B.n_rows), C.n_rows)), D.n_rows);
307   const uword out_n_cols = A.n_cols + B.n_cols + C.n_cols + D.n_cols;
308 
309   arma_debug_check( ((A.n_rows != out_n_rows) && ((A.n_rows > 0) || (A.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
310   arma_debug_check( ((B.n_rows != out_n_rows) && ((B.n_rows > 0) || (B.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
311   arma_debug_check( ((C.n_rows != out_n_rows) && ((C.n_rows > 0) || (C.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
312   arma_debug_check( ((D.n_rows != out_n_rows) && ((D.n_rows > 0) || (D.n_cols > 0))), "join_rows() / join_horiz(): number of rows must be the same" );
313 
314   out.set_size(out_n_rows, out_n_cols);
315 
316   if(out.n_elem == 0)  { return; }
317 
318   uword col_start  = 0;
319   uword col_end_p1 = 0;
320 
321   if(A.n_elem > 0)  { col_end_p1 += A.n_cols; out.cols(col_start, col_end_p1 - 1) = A; }
322 
323   col_start = col_end_p1;
324 
325   if(B.n_elem > 0)  { col_end_p1 += B.n_cols; out.cols(col_start, col_end_p1 - 1) = B; }
326 
327   col_start = col_end_p1;
328 
329   if(C.n_elem > 0)  { col_end_p1 += C.n_cols; out.cols(col_start, col_end_p1 - 1) = C; }
330 
331   col_start = col_end_p1;
332 
333   if(D.n_elem > 0)  { col_end_p1 += D.n_cols; out.cols(col_start, col_end_p1 - 1) = D; }
334   }
335 
336 
337 
338 template<typename T1, typename T2>
339 inline
340 void
apply(Cube<typename T1::elem_type> & out,const GlueCube<T1,T2,glue_join_slices> & X)341 glue_join_slices::apply(Cube<typename T1::elem_type>& out, const GlueCube<T1,T2,glue_join_slices>& X)
342   {
343   arma_extra_debug_sigprint();
344 
345   typedef typename T1::elem_type eT;
346 
347   const unwrap_cube<T1> A_tmp(X.A);
348   const unwrap_cube<T2> B_tmp(X.B);
349 
350   const Cube<eT>& A = A_tmp.M;
351   const Cube<eT>& B = B_tmp.M;
352 
353   if(A.n_elem == 0)  { out = B; return; }
354   if(B.n_elem == 0)  { out = A; return; }
355 
356   arma_debug_check( ( (A.n_rows != B.n_rows) || (A.n_cols != B.n_cols) ), "join_slices(): size of slices must be the same" );
357 
358   if( (&out != &A) && (&out != &B) )
359     {
360     out.set_size(A.n_rows, A.n_cols, A.n_slices + B.n_slices);
361 
362     out.slices(0,          A.n_slices-1  ) = A;
363     out.slices(A.n_slices, out.n_slices-1) = B;
364     }
365   else  // we have aliasing
366     {
367     Cube<eT> C(A.n_rows, A.n_cols, A.n_slices + B.n_slices, arma_nozeros_indicator());
368 
369     C.slices(0,          A.n_slices-1) = A;
370     C.slices(A.n_slices, C.n_slices-1) = B;
371 
372     out.steal_mem(C);
373     }
374 
375   }
376 
377 
378 
379 //! @}
380