1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup op_trimat
20 //! @{
21
22
23
24 template<typename eT>
25 inline
26 void
fill_zeros(Mat<eT> & out,const bool upper)27 op_trimat::fill_zeros(Mat<eT>& out, const bool upper)
28 {
29 arma_extra_debug_sigprint();
30
31 const uword N = out.n_rows;
32
33 if(upper)
34 {
35 // upper triangular: set all elements below the diagonal to zero
36
37 for(uword i=0; i<N; ++i)
38 {
39 eT* data = out.colptr(i);
40
41 arrayops::fill_zeros( &data[i+1], (N-(i+1)) );
42 }
43 }
44 else
45 {
46 // lower triangular: set all elements above the diagonal to zero
47
48 for(uword i=1; i<N; ++i)
49 {
50 eT* data = out.colptr(i);
51
52 arrayops::fill_zeros( data, i );
53 }
54 }
55 }
56
57
58
59 template<typename T1>
60 inline
61 void
apply(Mat<typename T1::elem_type> & out,const Op<T1,op_trimat> & in)62 op_trimat::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimat>& in)
63 {
64 arma_extra_debug_sigprint();
65
66 typedef typename T1::elem_type eT;
67
68 const bool upper = (in.aux_uword_a == 0);
69
70 // allow detection of in-place operation
71 if(is_Mat<T1>::value || (arma_config::openmp && Proxy<T1>::use_mp))
72 {
73 const unwrap<T1> U(in.m);
74
75 op_trimat::apply_unwrap(out, U.M, upper);
76 }
77 else
78 {
79 const Proxy<T1> P(in.m);
80
81 const bool is_alias = P.is_alias(out);
82
83 if(is_Mat<typename Proxy<T1>::stored_type>::value)
84 {
85 const quasi_unwrap<typename Proxy<T1>::stored_type> U(P.Q);
86
87 if(is_alias)
88 {
89 Mat<eT> tmp;
90
91 op_trimat::apply_unwrap(tmp, U.M, upper);
92
93 out.steal_mem(tmp);
94 }
95 else
96 {
97 op_trimat::apply_unwrap(out, U.M, upper);
98 }
99 }
100 else
101 {
102 if(is_alias)
103 {
104 Mat<eT> tmp;
105
106 op_trimat::apply_proxy(tmp, P, upper);
107
108 out.steal_mem(tmp);
109 }
110 else
111 {
112 op_trimat::apply_proxy(out, P, upper);
113 }
114 }
115 }
116 }
117
118
119
120 template<typename eT>
121 inline
122 void
apply_unwrap(Mat<eT> & out,const Mat<eT> & A,const bool upper)123 op_trimat::apply_unwrap(Mat<eT>& out, const Mat<eT>& A, const bool upper)
124 {
125 arma_extra_debug_sigprint();
126
127 arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square sized" );
128
129 if(&out != &A)
130 {
131 out.copy_size(A);
132
133 const uword N = A.n_rows;
134
135 if(upper)
136 {
137 // upper triangular: copy the diagonal and the elements above the diagonal
138 for(uword i=0; i<N; ++i)
139 {
140 const eT* A_data = A.colptr(i);
141 eT* out_data = out.colptr(i);
142
143 arrayops::copy( out_data, A_data, i+1 );
144 }
145 }
146 else
147 {
148 // lower triangular: copy the diagonal and the elements below the diagonal
149 for(uword i=0; i<N; ++i)
150 {
151 const eT* A_data = A.colptr(i);
152 eT* out_data = out.colptr(i);
153
154 arrayops::copy( &out_data[i], &A_data[i], N-i );
155 }
156 }
157 }
158
159 op_trimat::fill_zeros(out, upper);
160 }
161
162
163
164 template<typename T1>
165 inline
166 void
apply_proxy(Mat<typename T1::elem_type> & out,const Proxy<T1> & P,const bool upper)167 op_trimat::apply_proxy(Mat<typename T1::elem_type>& out, const Proxy<T1>& P, const bool upper)
168 {
169 arma_extra_debug_sigprint();
170
171 arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "trimatu()/trimatl(): given matrix must be square sized" );
172
173 const uword N = P.get_n_rows();
174
175 out.set_size(N,N);
176
177 if(upper)
178 {
179 for(uword j=0; j < N; ++j)
180 for(uword i=0; i < (j+1); ++i)
181 {
182 out.at(i,j) = P.at(i,j);
183 }
184 }
185 else
186 {
187 for(uword j=0; j<N; ++j)
188 for(uword i=j; i<N; ++i)
189 {
190 out.at(i,j) = P.at(i,j);
191 }
192 }
193
194 op_trimat::fill_zeros(out, upper);
195 }
196
197
198
199 //
200
201
202
203 template<typename T1>
204 inline
205 void
apply(Mat<typename T1::elem_type> & out,const Op<T1,op_trimatu_ext> & in)206 op_trimatu_ext::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimatu_ext>& in)
207 {
208 arma_extra_debug_sigprint();
209
210 typedef typename T1::elem_type eT;
211
212 const unwrap<T1> tmp(in.m);
213 const Mat<eT>& A = tmp.M;
214
215 arma_debug_check( (A.is_square() == false), "trimatu(): given matrix must be square sized" );
216
217 const uword row_offset = in.aux_uword_a;
218 const uword col_offset = in.aux_uword_b;
219
220 const uword n_rows = A.n_rows;
221 const uword n_cols = A.n_cols;
222
223 arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" );
224
225 if(&out != &A)
226 {
227 out.copy_size(A);
228
229 const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset);
230
231 for(uword i=0; i < n_cols; ++i)
232 {
233 const uword col = i + col_offset;
234
235 if(i < N)
236 {
237 const uword end_row = i + row_offset;
238
239 for(uword row=0; row <= end_row; ++row)
240 {
241 out.at(row,col) = A.at(row,col);
242 }
243 }
244 else
245 {
246 if(col < n_cols)
247 {
248 arrayops::copy(out.colptr(col), A.colptr(col), n_rows);
249 }
250 }
251 }
252 }
253
254 op_trimatu_ext::fill_zeros(out, row_offset, col_offset);
255 }
256
257
258
259 template<typename eT>
260 inline
261 void
fill_zeros(Mat<eT> & out,const uword row_offset,const uword col_offset)262 op_trimatu_ext::fill_zeros(Mat<eT>& out, const uword row_offset, const uword col_offset)
263 {
264 arma_extra_debug_sigprint();
265
266 const uword n_rows = out.n_rows;
267 const uword n_cols = out.n_cols;
268
269 const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset);
270
271 for(uword col=0; col < col_offset; ++col)
272 {
273 arrayops::fill_zeros(out.colptr(col), n_rows);
274 }
275
276 for(uword i=0; i < N; ++i)
277 {
278 const uword start_row = i + row_offset + 1;
279 const uword col = i + col_offset;
280
281 for(uword row=start_row; row < n_rows; ++row)
282 {
283 out.at(row,col) = eT(0);
284 }
285 }
286 }
287
288
289
290 //
291
292
293
294 template<typename T1>
295 inline
296 void
apply(Mat<typename T1::elem_type> & out,const Op<T1,op_trimatl_ext> & in)297 op_trimatl_ext::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_trimatl_ext>& in)
298 {
299 arma_extra_debug_sigprint();
300
301 typedef typename T1::elem_type eT;
302
303 const unwrap<T1> tmp(in.m);
304 const Mat<eT>& A = tmp.M;
305
306 arma_debug_check( (A.is_square() == false), "trimatl(): given matrix must be square sized" );
307
308 const uword row_offset = in.aux_uword_a;
309 const uword col_offset = in.aux_uword_b;
310
311 const uword n_rows = A.n_rows;
312 const uword n_cols = A.n_cols;
313
314 arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" );
315
316 if(&out != &A)
317 {
318 out.copy_size(A);
319
320 const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset);
321
322 for(uword col=0; col < col_offset; ++col)
323 {
324 arrayops::copy( out.colptr(col), A.colptr(col), n_rows );
325 }
326
327 for(uword i=0; i<N; ++i)
328 {
329 const uword start_row = i + row_offset;
330 const uword col = i + col_offset;
331
332 for(uword row=start_row; row < n_rows; ++row)
333 {
334 out.at(row,col) = A.at(row,col);
335 }
336 }
337 }
338
339 op_trimatl_ext::fill_zeros(out, row_offset, col_offset);
340 }
341
342
343
344 template<typename eT>
345 inline
346 void
fill_zeros(Mat<eT> & out,const uword row_offset,const uword col_offset)347 op_trimatl_ext::fill_zeros(Mat<eT>& out, const uword row_offset, const uword col_offset)
348 {
349 arma_extra_debug_sigprint();
350
351 const uword n_rows = out.n_rows;
352 const uword n_cols = out.n_cols;
353
354 const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset);
355
356 for(uword i=0; i < n_cols; ++i)
357 {
358 const uword col = i + col_offset;
359
360 if(i < N)
361 {
362 const uword end_row = i + row_offset;
363
364 for(uword row=0; row < end_row; ++row)
365 {
366 out.at(row,col) = eT(0);
367 }
368 }
369 else
370 {
371 if(col < n_cols)
372 {
373 arrayops::fill_zeros(out.colptr(col), n_rows);
374 }
375 }
376 }
377 }
378
379
380
381 //! @}
382