1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup band_helper
20 //! @{
21
22
23 namespace band_helper
24 {
25
26
27
28 template<typename eT>
29 inline
30 bool
is_band(uword & out_KL,uword & out_KU,const Mat<eT> & A,const uword N_min)31 is_band(uword& out_KL, uword& out_KU, const Mat<eT>& A, const uword N_min)
32 {
33 arma_extra_debug_sigprint();
34
35 // NOTE: assuming that A has a square size
36 // NOTE: assuming that N_min is >= 4
37
38 const uword N = A.n_rows;
39
40 if(N < N_min) { return false; }
41
42 // first, quickly check bottom-left and top-right corners
43
44 const eT eT_zero = eT(0);
45
46 const eT* A_col0 = A.memptr();
47 const eT* A_col1 = A_col0 + N;
48
49 if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; }
50
51 const eT* A_colNm2 = A.colptr(N-2);
52 const eT* A_colNm1 = A_colNm2 + N;
53
54 if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; }
55
56 // if we reached this point, go through the entire matrix to work out number of subdiagonals and superdiagonals
57
58 const uword n_nonzero_threshold = (N*N)/4; // empirically determined
59
60 uword KL = 0; // number of subdiagonals
61 uword KU = 0; // number of superdiagonals
62
63 const eT* A_colptr = A.memptr();
64
65 for(uword col=0; col < N; ++col)
66 {
67 uword first_nonzero_row = col;
68 uword last_nonzero_row = col;
69
70 for(uword row=0; row < col; ++row)
71 {
72 if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; }
73 }
74
75 for(uword row=(col+1); row < N; ++row)
76 {
77 last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row;
78 }
79
80 const uword L_count = last_nonzero_row - col;
81 const uword U_count = col - first_nonzero_row;
82
83 if( (L_count > KL) || (U_count > KU) )
84 {
85 KL = (std::max)(KL, L_count);
86 KU = (std::max)(KU, U_count);
87
88 const uword n_nonzero = N*(KL+KU+1) - (KL*(KL+1) + KU*(KU+1))/2;
89
90 // return as soon as we know that it's not worth analysing the matrix any further
91
92 if(n_nonzero > n_nonzero_threshold) { return false; }
93 }
94
95 A_colptr += N;
96 }
97
98 out_KL = KL;
99 out_KU = KU;
100
101 return true;
102 }
103
104
105
106 template<typename eT>
107 inline
108 bool
is_band_lower(uword & out_KD,const Mat<eT> & A,const uword N_min)109 is_band_lower(uword& out_KD, const Mat<eT>& A, const uword N_min)
110 {
111 arma_extra_debug_sigprint();
112
113 // NOTE: assuming that A has a square size
114 // NOTE: assuming that N_min is >= 4
115
116 const uword N = A.n_rows;
117
118 if(N < N_min) { return false; }
119
120 // first, quickly check bottom-left corner
121
122 const eT eT_zero = eT(0);
123
124 const eT* A_col0 = A.memptr();
125 const eT* A_col1 = A_col0 + N;
126
127 if( (A_col0[N-2] != eT_zero) || (A_col0[N-1] != eT_zero) || (A_col1[N-2] != eT_zero) || (A_col1[N-1] != eT_zero) ) { return false; }
128
129 // if we reached this point, go through the bottom triangle to work out number of subdiagonals
130
131 const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined
132
133 uword KL = 0; // number of subdiagonals
134
135 const eT* A_colptr = A.memptr();
136
137 for(uword col=0; col < N; ++col)
138 {
139 uword last_nonzero_row = col;
140
141 for(uword row=(col+1); row < N; ++row)
142 {
143 last_nonzero_row = (A_colptr[row] != eT_zero) ? row : last_nonzero_row;
144 }
145
146 const uword L_count = last_nonzero_row - col;
147
148 if(L_count > KL)
149 {
150 KL = L_count;
151
152 const uword n_nonzero = N*(KL+1) - (KL*(KL+1))/2;
153
154 // return as soon as we know that it's not worth analysing the matrix any further
155
156 if(n_nonzero > n_nonzero_threshold) { return false; }
157 }
158
159 A_colptr += N;
160 }
161
162 out_KD = KL;
163
164 return true;
165 }
166
167
168
169 template<typename eT>
170 inline
171 bool
is_band_upper(uword & out_KD,const Mat<eT> & A,const uword N_min)172 is_band_upper(uword& out_KD, const Mat<eT>& A, const uword N_min)
173 {
174 arma_extra_debug_sigprint();
175
176 // NOTE: assuming that A has a square size
177 // NOTE: assuming that N_min is >= 4
178
179 const uword N = A.n_rows;
180
181 if(N < N_min) { return false; }
182
183 // first, quickly check top-right corner
184
185 const eT eT_zero = eT(0);
186
187 const eT* A_colNm2 = A.colptr(N-2);
188 const eT* A_colNm1 = A_colNm2 + N;
189
190 if( (A_colNm2[0] != eT_zero) || (A_colNm2[1] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; }
191
192 // if we reached this point, go through the entire matrix to work out number of superdiagonals
193
194 const uword n_nonzero_threshold = ( N*N - (N*(N-1))/2 ) / 4; // empirically determined
195
196 uword KU = 0; // number of superdiagonals
197
198 const eT* A_colptr = A.memptr();
199
200 for(uword col=0; col < N; ++col)
201 {
202 uword first_nonzero_row = col;
203
204 for(uword row=0; row < col; ++row)
205 {
206 if( A_colptr[row] != eT_zero ) { first_nonzero_row = row; break; }
207 }
208
209 const uword U_count = col - first_nonzero_row;
210
211 if(U_count > KU)
212 {
213 KU = U_count;
214
215 const uword n_nonzero = N*(KU+1) - (KU*(KU+1))/2;
216
217 // return as soon as we know that it's not worth analysing the matrix any further
218
219 if(n_nonzero > n_nonzero_threshold) { return false; }
220 }
221
222 A_colptr += N;
223 }
224
225 out_KD = KU;
226
227 return true;
228 }
229
230
231
232 template<typename eT>
233 inline
234 void
compress(Mat<eT> & AB,const Mat<eT> & A,const uword KL,const uword KU,const bool use_offset)235 compress(Mat<eT>& AB, const Mat<eT>& A, const uword KL, const uword KU, const bool use_offset)
236 {
237 arma_extra_debug_sigprint();
238
239 // NOTE: assuming that A has a square size
240
241 // band matrix storage format
242 // http://www.netlib.org/lapack/lug/node124.html
243
244 // for ?gbsv, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1)
245 // for ?gbsvx, matrix AB size: KL+KU+1 x N; band representaiton of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1)
246 //
247 // the +1 in the above formulas is to take into account the main diagonal
248
249 const uword AB_n_rows = (use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1);
250 const uword N = A.n_rows;
251
252 AB.set_size(AB_n_rows, N);
253
254 if(A.is_empty()) { AB.zeros(); return; }
255
256 if(AB_n_rows == uword(1))
257 {
258 eT* AB_mem = AB.memptr();
259
260 for(uword i=0; i<N; ++i) { AB_mem[i] = A.at(i,i); }
261 }
262 else
263 {
264 AB.zeros(); // paranoia
265
266 for(uword j=0; j < N; ++j)
267 {
268 const uword A_row_start = (j > KU) ? uword(j - KU) : uword(0);
269 const uword A_row_endp1 = (std::min)(N, j+KL+1);
270
271 const uword length = A_row_endp1 - A_row_start;
272
273 const uword AB_row_start = (KU > j) ? (KU - j) : uword(0);
274
275 const eT* A_colptr = A.colptr(j) + A_row_start;
276 eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) );
277
278 arrayops::copy( AB_colptr, A_colptr, length );
279 }
280 }
281 }
282
283
284
285 template<typename eT>
286 inline
287 void
uncompress(Mat<eT> & A,const Mat<eT> & AB,const uword KL,const uword KU,const bool use_offset)288 uncompress(Mat<eT>& A, const Mat<eT>& AB, const uword KL, const uword KU, const bool use_offset)
289 {
290 arma_extra_debug_sigprint();
291
292 const uword AB_n_rows = AB.n_rows;
293 const uword N = AB.n_cols;
294
295 arma_debug_check( (AB_n_rows != ((use_offset) ? uword(2*KL + KU + 1) : uword(KL + KU + 1))), "band_helper::uncompress(): detected inconsistency" );
296
297 A.zeros(N,N); // assuming there is no aliasing between A and AB
298
299 if(AB_n_rows == uword(1))
300 {
301 const eT* AB_mem = AB.memptr();
302
303 for(uword i=0; i<N; ++i) { A.at(i,i) = AB_mem[i]; }
304 }
305 else
306 {
307 for(uword j=0; j < N; ++j)
308 {
309 const uword A_row_start = (j > KU) ? uword(j - KU) : uword(0);
310 const uword A_row_endp1 = (std::min)(N, j+KL+1);
311
312 const uword length = A_row_endp1 - A_row_start;
313
314 const uword AB_row_start = (KU > j) ? (KU - j) : uword(0);
315
316 const eT* AB_colptr = AB.colptr(j) + AB_row_start + ( (use_offset) ? KL : uword(0) );
317 eT* A_colptr = A.colptr(j) + A_row_start;
318
319 arrayops::copy( A_colptr, AB_colptr, length );
320 }
321 }
322 }
323
324
325
326 template<typename eT>
327 inline
328 void
extract_tridiag(Mat<eT> & out,const Mat<eT> & A)329 extract_tridiag(Mat<eT>& out, const Mat<eT>& A)
330 {
331 arma_extra_debug_sigprint();
332
333 // NOTE: assuming that A has a square size and is at least 2x2
334
335 const uword N = A.n_rows;
336
337 out.set_size(N, 3); // assuming there is no aliasing between 'out' and 'A'
338
339 if(N < 2) { return; }
340
341 eT* DL = out.colptr(0);
342 eT* DD = out.colptr(1);
343 eT* DU = out.colptr(2);
344
345 DD[0] = A[0];
346 DL[0] = A[1];
347
348 const uword Nm1 = N-1;
349 const uword Nm2 = N-2;
350
351 for(uword i=0; i < Nm2; ++i)
352 {
353 const uword ip1 = i+1;
354
355 const eT* data = &(A.at(i, ip1));
356
357 const eT tmp0 = data[0];
358 const eT tmp1 = data[1];
359 const eT tmp2 = data[2];
360
361 DL[ip1] = tmp2;
362 DD[ip1] = tmp1;
363 DU[i ] = tmp0;
364 }
365
366 const eT* data = &(A.at(Nm2, Nm1));
367
368 DL[Nm1] = 0;
369 DU[Nm2] = data[0];
370 DU[Nm1] = 0;
371 DD[Nm1] = data[1];
372 }
373
374
375
376 } // end of namespace band_helper
377
378
379 //! @}
380