1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Bsvd_sinval_v_opt_var1(FLA_Obj tol,FLA_Obj thresh,FLA_Obj G,FLA_Obj H,FLA_Obj d,FLA_Obj e,FLA_Obj k)13 FLA_Error FLA_Bsvd_sinval_v_opt_var1( FLA_Obj tol, FLA_Obj thresh,
14                                       FLA_Obj G, FLA_Obj H,
15                                       FLA_Obj d, FLA_Obj e,
16                                       FLA_Obj k )
17 {
18     FLA_Datatype datatype;
19     int          m_A, n_GH;
20     int          rs_G, cs_G;
21     int          rs_H, cs_H;
22     int          inc_d;
23     int          inc_e;
24 
25     datatype = FLA_Obj_datatype( d );
26 
27     m_A      = FLA_Obj_vector_dim( d );
28     n_GH     = FLA_Obj_width( G );
29 
30     rs_G     = FLA_Obj_row_stride( G );
31     cs_G     = FLA_Obj_col_stride( G );
32 
33     rs_H     = FLA_Obj_row_stride( H );
34     cs_H     = FLA_Obj_col_stride( H );
35 
36     inc_d    = FLA_Obj_vector_inc( d );
37     inc_e    = FLA_Obj_vector_inc( e );
38 
39 
40     switch ( datatype )
41     {
42     case FLA_FLOAT:
43     {
44         float*    buff_tol    = FLA_FLOAT_PTR( tol );
45         float*    buff_thresh = FLA_FLOAT_PTR( thresh );
46         scomplex* buff_G      = FLA_COMPLEX_PTR( G );
47         scomplex* buff_H      = FLA_COMPLEX_PTR( H );
48         float*    buff_d      = FLA_FLOAT_PTR( d );
49         float*    buff_e      = FLA_FLOAT_PTR( e );
50         int*      buff_k      = FLA_INT_PTR( k );
51 
52         FLA_Bsvd_sinval_v_ops_var1( m_A,
53                                     n_GH,
54                                     9,
55                                     *buff_tol,
56                                     *buff_thresh,
57                                     buff_G, rs_G, cs_G,
58                                     buff_H, rs_H, cs_H,
59                                     buff_d, inc_d,
60                                     buff_e, inc_e,
61                                     buff_k );
62 
63         break;
64     }
65 
66     case FLA_DOUBLE:
67     {
68         double*   buff_tol    = FLA_DOUBLE_PTR( tol );
69         double*   buff_thresh = FLA_DOUBLE_PTR( thresh );
70         dcomplex* buff_G      = FLA_DOUBLE_COMPLEX_PTR( G );
71         dcomplex* buff_H      = FLA_DOUBLE_COMPLEX_PTR( H );
72         double*   buff_d      = FLA_DOUBLE_PTR( d );
73         double*   buff_e      = FLA_DOUBLE_PTR( e );
74         int*      buff_k      = FLA_INT_PTR( k );
75 
76         FLA_Bsvd_sinval_v_opd_var1( m_A,
77                                     n_GH,
78                                     9,
79                                     *buff_tol,
80                                     *buff_thresh,
81                                     buff_G, rs_G, cs_G,
82                                     buff_H, rs_H, cs_H,
83                                     buff_d, inc_d,
84                                     buff_e, inc_e,
85                                     buff_k );
86 
87         break;
88     }
89     }
90 
91     return FLA_SUCCESS;
92 }
93 
94 
95 
FLA_Bsvd_sinval_v_ops_var1(int m_A,int n_GH,int n_iter_allowed,float tol,float thresh,scomplex * buff_G,int rs_G,int cs_G,scomplex * buff_H,int rs_H,int cs_H,float * buff_d,int inc_d,float * buff_e,int inc_e,int * n_iter)96 FLA_Error FLA_Bsvd_sinval_v_ops_var1( int       m_A,
97                                       int       n_GH,
98                                       int       n_iter_allowed,
99                                       float     tol,
100                                       float     thresh,
101                                       scomplex* buff_G, int rs_G, int cs_G,
102                                       scomplex* buff_H, int rs_H, int cs_H,
103                                       float*    buff_d, int inc_d,
104                                       float*    buff_e, int inc_e,
105                                       int*      n_iter )
106 {
107     FLA_Error r_val;
108     float     one = bl1_s1();
109     //float*   d_first;
110     //float*   d_last_m1;
111     float*    e_last;
112     //float*   d_last;
113     float     smax;
114     float     smin;
115     float     sminl;
116     float     shift;
117     int       k;
118 
119     // Initialize pointers to some diagonal and superdiagonal elements
120     // that we will refer to later.
121     e_last    = buff_e + (m_A-2)*inc_e;
122     //d_last_m1 = buff_d + (m_A-2)*inc_d;
123     //d_last    = buff_d + (m_A-1)*inc_d;
124     //d_first   = buff_d + (0    )*inc_d;
125 
126     // Find the largest element of the diagonal or superdiagonal.
127     // This is used later when checking the shift.
128     FLA_Bsvd_find_max_min_ops( m_A,
129                                buff_d, inc_d,
130                                buff_e, inc_e,
131                                &smax,
132                                &smin );
133 
134     // Perform some iterations.
135     for ( k = 0; k < n_iter_allowed; ++k )
136     {
137         scomplex* g1 = buff_G + (k  )*cs_G;
138         scomplex* h1 = buff_H + (k  )*cs_H;
139 
140         /*------------------------------------------------------------*/
141 
142         // Before we perform any rotations, check for pre-existing deflation.
143         r_val = FLA_Bsvd_find_converged_ops( m_A,
144                                              tol,
145                                              buff_d, inc_d,
146                                              buff_e, inc_e,
147                                              &sminl );
148 
149         // If r_val is positive, then deflation was found.
150         if ( 0 <= r_val )
151         {
152             // Set the off-diagonal element to zero.
153             buff_e[ (r_val)*inc_e ] = 0.0F;
154 
155             *n_iter = k;
156             return r_val;
157         }
158 
159 
160         // Compute a shift with the last 2x2 matrix.
161         FLA_Bsvd_compute_shift_ops( m_A,
162                                     tol,
163                                     sminl,
164                                     smax,
165                                     buff_d, inc_d,
166                                     buff_e, inc_e,
167                                     &shift );
168 
169         // Perform a Francis step.
170         r_val = FLA_Bsvd_francis_v_ops_var1( m_A,
171                                              shift,
172                                              g1,     rs_G,
173                                              h1,     rs_H,
174                                              buff_d, inc_d,
175                                              buff_e, inc_e );
176 
177         // Check for convergence using thresh.
178         if ( MAC_Bsvd_sinval_is_converged_ops( thresh, one, *e_last ) )
179         {
180             *e_last = 0.0F;
181             *n_iter = k + 1;
182             return m_A - 1;
183         }
184 
185         /*------------------------------------------------------------*/
186     }
187 
188     *n_iter = n_iter_allowed;
189     return FLA_SUCCESS;
190 }
191 
192 //#define PRINTF
193 
FLA_Bsvd_sinval_v_opd_var1(int m_A,int n_GH,int n_iter_allowed,double tol,double thresh,dcomplex * buff_G,int rs_G,int cs_G,dcomplex * buff_H,int rs_H,int cs_H,double * buff_d,int inc_d,double * buff_e,int inc_e,int * n_iter)194 FLA_Error FLA_Bsvd_sinval_v_opd_var1( int       m_A,
195                                       int       n_GH,
196                                       int       n_iter_allowed,
197                                       double    tol,
198                                       double    thresh,
199                                       dcomplex* buff_G, int rs_G, int cs_G,
200                                       dcomplex* buff_H, int rs_H, int cs_H,
201                                       double*   buff_d, int inc_d,
202                                       double*   buff_e, int inc_e,
203                                       int*      n_iter )
204 {
205     FLA_Error r_val;
206     double    one = bl1_d1();
207     //double*   d_first;
208     //double*   d_last_m1;
209     double*   e_last;
210     //double*   d_last;
211     double    smax;
212     double    smin;
213     double    sminl;
214     double    shift;
215     int       k;
216 
217     // Initialize pointers to some diagonal and superdiagonal elements
218     // that we will refer to later.
219     e_last    = buff_e + (m_A-2)*inc_e;
220     //d_last_m1 = buff_d + (m_A-2)*inc_d;
221     //d_last    = buff_d + (m_A-1)*inc_d;
222     //d_first   = buff_d + (0    )*inc_d;
223 
224     // Find the largest element of the diagonal or superdiagonal.
225     // This is used later when checking the shift.
226     FLA_Bsvd_find_max_min_opd( m_A,
227                                buff_d, inc_d,
228                                buff_e, inc_e,
229                                &smax,
230                                &smin );
231 
232     // Perform some iterations.
233     for ( k = 0; k < n_iter_allowed; ++k )
234     {
235         dcomplex* g1 = buff_G + (k  )*cs_G;
236         dcomplex* h1 = buff_H + (k  )*cs_H;
237 
238         /*------------------------------------------------------------*/
239 
240         // Before we perform any rotations, check for pre-existing deflation.
241         r_val = FLA_Bsvd_find_converged_opd( m_A,
242                                              tol,
243                                              buff_d, inc_d,
244                                              buff_e, inc_e,
245                                              &sminl );
246 
247         // If r_val is positive, then deflation was found.
248         if ( 0 <= r_val )
249         {
250 #ifdef PRINTF
251             printf( "FLA_Bsvd_sinval_v_opt_var1: Deflation detected in col %d, sval %d\n", r_val, m_A - 1 );
252             printf( "FLA_Bsvd_sinval_v_opt_var1: alpha11 alpha12 = %23.19e %23.19e\n", buff_d[r_val*inc_d], buff_e[r_val*inc_e] );
253             printf( "FLA_Bsvd_sinval_v_opt_var1:         alpha22 =         %43.19e\n", buff_d[(r_val+1)*inc_d] );
254 #endif
255 
256             // Set the off-diagonal element to zero.
257             buff_e[ (r_val)*inc_e ] = 0.0;
258 
259             *n_iter = k;
260             return r_val;
261         }
262 
263 
264         // Compute a shift with the last 2x2 matrix.
265         FLA_Bsvd_compute_shift_opd( m_A,
266                                     tol,
267                                     sminl,
268                                     smax,
269                                     buff_d, inc_d,
270                                     buff_e, inc_e,
271                                     &shift );
272 
273         // Perform a Francis step.
274         r_val = FLA_Bsvd_francis_v_opd_var1( m_A,
275                                              shift,
276                                              g1,     rs_G,
277                                              h1,     rs_H,
278                                              buff_d, inc_d,
279                                              buff_e, inc_e );
280 
281         // Check for convergence using thresh.
282         if ( MAC_Bsvd_sinval_is_converged_opd( thresh, one, *e_last ) )
283         {
284             *e_last = 0.0;
285             *n_iter = k + 1;
286             return m_A - 1;
287         }
288 
289         /*------------------------------------------------------------*/
290     }
291 
292     *n_iter = n_iter_allowed;
293     return FLA_FAILURE;
294 }
295 
296