1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Bsvd_iteracc_v_ops_var1(int m_A,int n_GH,int ijTL,float tol,float thresh,float * buff_d,int inc_d,float * buff_e,int inc_e,scomplex * buff_G,int rs_G,int cs_G,scomplex * buff_H,int rs_H,int cs_H,int * n_iter_perf)13 FLA_Error FLA_Bsvd_iteracc_v_ops_var1( int       m_A,
14                                        int       n_GH,
15                                        int       ijTL,
16                                        float     tol,
17                                        float     thresh,
18                                        float*    buff_d, int inc_d,
19                                        float*    buff_e, int inc_e,
20                                        scomplex* buff_G, int rs_G, int cs_G,
21                                        scomplex* buff_H, int rs_H, int cs_H,
22                                        int*      n_iter_perf )
23 {
24     FLA_Error r_val;
25     int       i, k;
26     int       k_iter       = 0;
27     int       n_deflations = 0;
28 
29     // Iterate from back to front until all that is left is a 2x2.
30     for ( i = m_A - 1; i > 1; --i )
31     {
32         scomplex* G1     = buff_G + (k_iter)*cs_G;
33         scomplex* H1     = buff_H + (k_iter)*cs_H;
34         int       m_ATL  = i + 1;
35         int       k_left = n_GH - k_iter;
36 
37         /*------------------------------------------------------------*/
38 
39         // Find a singular value of ATL submatrix.
40         r_val = FLA_Bsvd_sinval_v_ops_var1( m_ATL,
41                                             n_GH,
42                                             k_left,
43                                             tol,
44                                             thresh,
45                                             G1,     rs_G, cs_G,
46                                             H1,     rs_H, cs_H,
47                                             buff_d, inc_d,
48                                             buff_e, inc_e,
49                                             &k );
50 
51         // Update local counters according to the results of the singular
52         // value search.
53         k_iter       += k;
54         n_deflations += 1;
55 
56         if ( r_val == FLA_FAILURE )
57         {
58             *n_iter_perf = k_iter;
59             return n_deflations;
60         }
61 
62         // If the most recent singular value search put us at our
63         // limit for accumulated Givens rotation sets, return.
64         if ( k_iter == n_GH )
65         {
66             *n_iter_perf = k_iter;
67             return n_deflations;
68         }
69 
70         // If r_val != i, then a split occurred somewhere within ATL.
71         // Therefore, we must recurse with subproblems.
72         if ( r_val != i )
73         {
74             int       m_TLr = r_val + 1;
75             int       m_BRr = m_ATL - m_TLr;
76             int       ijTLr = 0;
77             int       ijBRr = m_TLr;
78             int       n_GHr = n_GH - k_iter;
79             float*    dTL   = buff_d + (0    )*inc_d;
80             float*    eTL   = buff_e + (0    )*inc_e;
81             scomplex* GT    = buff_G + (0    )*rs_G + (k_iter)*cs_G;
82             scomplex* HT    = buff_H + (0    )*rs_H + (k_iter)*cs_H;
83             float*    dBR   = buff_d + (ijBRr)*inc_d;
84             float*    eBR   = buff_e + (ijBRr)*inc_e;
85             scomplex* GB    = buff_G + (ijBRr)*rs_G + (k_iter)*cs_G;
86             scomplex* HB    = buff_H + (ijBRr)*rs_H + (k_iter)*cs_H;
87 
88             int       n_deflationsTL;
89             int       n_deflationsBR;
90             int       n_iter_perfTL;
91             int       n_iter_perfBR;
92 
93             n_deflationsTL = FLA_Bsvd_iteracc_v_ops_var1( m_TLr,
94                                                           n_GHr,
95                                                           ijTL + ijTLr,
96                                                           tol,
97                                                           thresh,
98                                                           dTL, inc_d,
99                                                           eTL, inc_e,
100                                                           GT,  rs_G, cs_G,
101                                                           HT,  rs_H, cs_H,
102                                                           &n_iter_perfTL );
103             n_deflationsBR = FLA_Bsvd_iteracc_v_ops_var1( m_BRr,
104                                                           n_GHr,
105                                                           ijTL + ijBRr,
106                                                           tol,
107                                                           thresh,
108                                                           dBR, inc_d,
109                                                           eBR, inc_e,
110                                                           GB,  rs_G, cs_G,
111                                                           HB,  rs_H, cs_H,
112                                                           &n_iter_perfBR );
113 
114             *n_iter_perf = k_iter + max( n_iter_perfTL, n_iter_perfBR );
115 
116             return n_deflations + n_deflationsTL + n_deflationsBR;
117         }
118 
119         /*------------------------------------------------------------*/
120     }
121 
122     // Skip 1x1 matrices (and submatrices) entirely.
123     if ( m_A > 1 )
124     {
125         scomplex* g1 = buff_G + (k_iter)*cs_G;
126         scomplex* h1 = buff_H + (k_iter)*cs_H;
127 
128         float*    alpha11 = buff_d + (0  )*inc_d;
129         float*    alpha12 = buff_e + (0  )*inc_e;
130         float*    alpha22 = buff_d + (1  )*inc_d;
131 
132         float     smin;
133         float     smax;
134 
135         float     gammaL;
136         float     sigmaL;
137         float     gammaR;
138         float     sigmaR;
139 
140         // Find the singular value decomposition of the remaining (or only)
141         // 2x2 submatrix.
142         FLA_Svv_2x2_ops( alpha11,
143                          alpha12,
144                          alpha22,
145                          &smin,
146                          &smax,
147                          &gammaL,
148                          &sigmaL,
149                          &gammaR,
150                          &sigmaR );
151 
152         *alpha11 = smax;
153         *alpha22 = smin;
154 
155         // Zero out the remaining diagonal.
156         *alpha12 = 0.0F;
157 
158         // Store the rotations.
159         g1[0].real = gammaL;
160         g1[0].imag = sigmaL;
161         h1[0].real = gammaR;
162         h1[0].imag = sigmaR;
163 
164         // Update the local counters.
165         k_iter       += 1;
166         n_deflations += 1;
167 
168     }
169 
170     *n_iter_perf = k_iter;
171     return n_deflations;
172 }
173 
174 //#define PRINTF
175 
FLA_Bsvd_iteracc_v_opd_var1(int m_A,int n_GH,int ijTL,double tol,double thresh,double * buff_d,int inc_d,double * buff_e,int inc_e,dcomplex * buff_G,int rs_G,int cs_G,dcomplex * buff_H,int rs_H,int cs_H,int * n_iter_perf)176 FLA_Error FLA_Bsvd_iteracc_v_opd_var1( int       m_A,
177                                        int       n_GH,
178                                        int       ijTL,
179                                        double    tol,
180                                        double    thresh,
181                                        double*   buff_d, int inc_d,
182                                        double*   buff_e, int inc_e,
183                                        dcomplex* buff_G, int rs_G, int cs_G,
184                                        dcomplex* buff_H, int rs_H, int cs_H,
185                                        int*      n_iter_perf )
186 {
187     FLA_Error r_val;
188     int       i, k;
189     int       k_iter       = 0;
190     int       n_deflations = 0;
191 
192     // Iterate from back to front until all that is left is a 2x2.
193     for ( i = m_A - 1; i > 1; --i )
194     {
195         dcomplex* G1     = buff_G + (k_iter)*cs_G;
196         dcomplex* H1     = buff_H + (k_iter)*cs_H;
197         int       m_ATL  = i + 1;
198         int       k_left = n_GH - k_iter;
199 
200         /*------------------------------------------------------------*/
201 
202         // Find a singular value of ATL submatrix.
203         r_val = FLA_Bsvd_sinval_v_opd_var1( m_ATL,
204                                             n_GH,
205                                             k_left,
206                                             tol,
207                                             thresh,
208                                             G1,     rs_G, cs_G,
209                                             H1,     rs_H, cs_H,
210                                             buff_d, inc_d,
211                                             buff_e, inc_e,
212                                             &k );
213 
214         // Update local counters according to the results of the singular
215         // value search.
216         k_iter       += k;
217         n_deflations += 1;
218 
219         if ( r_val == FLA_FAILURE )
220         {
221 #ifdef PRINTF
222             printf( "FLA_Bsvd_iteracc_v_opd_var1: failed to converge (m_A11 = %d) after %2d iters k_total=%d/%d\n", i, k, k_iter, n_G );
223 #endif
224             *n_iter_perf = k_iter;
225             return n_deflations;
226         }
227 
228 #ifdef PRINTF
229         if ( r_val == i )
230             printf( "FLA_Bsvd_iteracc_v_opd_var1: found sv %22.15e in col %3d (n=%d) after %2d it  k_tot=%d/%d\n", buff_d[ r_val*inc_d ], ijTL+r_val, m_ATL, k, k_iter, n_GH );
231         else
232             printf( "FLA_Bsvd_iteracc_v_opd_var1: split occurred in col %3d. (n=%d) after %2d it  k_tot=%d/%d\n", r_val, m_ATL, k, k_iter, n_GH );
233 #endif
234 
235         // If the most recent singular value search put us at our
236         // limit for accumulated Givens rotation sets, return.
237         if ( k_iter == n_GH )
238         {
239             *n_iter_perf = k_iter;
240             return n_deflations;
241         }
242 
243         // If r_val != i, then a split occurred somewhere within ATL.
244         // Therefore, we must recurse with subproblems.
245         if ( r_val != i )
246         {
247             int       m_TLr = r_val + 1;
248             int       m_BRr = m_ATL - m_TLr;
249             int       ijTLr = 0;
250             int       ijBRr = m_TLr;
251             int       n_GHr = n_GH - k_iter;
252             double*   dTL   = buff_d + (0    )*inc_d;
253             double*   eTL   = buff_e + (0    )*inc_e;
254             dcomplex* GT    = buff_G + (0    )*rs_G + (k_iter)*cs_G;
255             dcomplex* HT    = buff_H + (0    )*rs_H + (k_iter)*cs_H;
256             double*   dBR   = buff_d + (ijBRr)*inc_d;
257             double*   eBR   = buff_e + (ijBRr)*inc_e;
258             dcomplex* GB    = buff_G + (ijBRr)*rs_G + (k_iter)*cs_G;
259             dcomplex* HB    = buff_H + (ijBRr)*rs_H + (k_iter)*cs_H;
260 
261             int       n_deflationsTL;
262             int       n_deflationsBR;
263             int       n_iter_perfTL;
264             int       n_iter_perfBR;
265 
266 #ifdef PRINTF
267             printf( "FLA_Bsvd_iteracc_v_opd_var1: Deflation occurred in col %d\n", r_val );
268             printf( "FLA_Bsvd_iteracc_v_opd_var1: alpha11 alpha12 = %22.15e %22.15e\n", buff_d[r_val*inc_d], buff_e[(r_val)*inc_e] );
269             printf( "FLA_Bsvd_iteracc_v_opd_var1:         alpha22 =         %37.15e\n", buff_d[(r_val+1)*inc_d] );
270 
271             printf( "FLA_Bsvd_iteracc_v_opd_var1: recursing: ijTLr m_TLr: %d %d\n", ijTLr, m_TLr );
272             printf( "FLA_Bsvd_iteracc_v_opd_var1:            GB(0,0) i,j: %d %d\n", ijTL + m_TLr+1, k_iter );
273 #endif
274             n_deflationsTL = FLA_Bsvd_iteracc_v_opd_var1( m_TLr,
275                              n_GHr,
276                              ijTL + ijTLr,
277                              tol,
278                              thresh,
279                              dTL, inc_d,
280                              eTL, inc_e,
281                              GT,  rs_G, cs_G,
282                              HT,  rs_H, cs_H,
283                              &n_iter_perfTL );
284 #ifdef PRINTF
285             printf( "FLA_Bsvd_iteracc_v_opd_var1: returning: ijTLr m_TLr: %d %d\n", ijTLr, m_TLr );
286             printf( "FLA_Bsvd_iteracc_v_opd_var1: recursing: ijBRr m_BRr: %d %d\n", ijBRr, m_BRr );
287             printf( "FLA_Bsvd_iteracc_v_opd_var1:            GB(0,0) i,j: %d %d\n", ijTL + m_TLr+1, k_iter );
288 #endif
289             n_deflationsBR = FLA_Bsvd_iteracc_v_opd_var1( m_BRr,
290                              n_GHr,
291                              ijTL + ijBRr,
292                              tol,
293                              thresh,
294                              dBR, inc_d,
295                              eBR, inc_e,
296                              GB,  rs_G, cs_G,
297                              HB,  rs_H, cs_H,
298                              &n_iter_perfBR );
299 #ifdef PRINTF
300             printf( "FLA_Bsvd_iteracc_v_opd_var1: returning: ijBRr m_BRr: %d %d\n", ijBRr, m_BRr );
301 #endif
302 
303             *n_iter_perf = k_iter + max( n_iter_perfTL, n_iter_perfBR );
304 
305             return n_deflations + n_deflationsTL + n_deflationsBR;
306         }
307 
308         /*------------------------------------------------------------*/
309     }
310 
311     // Skip 1x1 matrices (and submatrices) entirely.
312     if ( m_A > 1 )
313     {
314         dcomplex* g1 = buff_G + (k_iter)*cs_G;
315         dcomplex* h1 = buff_H + (k_iter)*cs_H;
316 
317         double*   alpha11 = buff_d + (0  )*inc_d;
318         double*   alpha12 = buff_e + (0  )*inc_e;
319         double*   alpha22 = buff_d + (1  )*inc_d;
320 
321         double    smin;
322         double    smax;
323 
324         double    gammaL;
325         double    sigmaL;
326         double    gammaR;
327         double    sigmaR;
328 
329         // Find the singular value decomposition of the remaining (or only)
330         // 2x2 submatrix.
331         FLA_Svv_2x2_opd( alpha11,
332                          alpha12,
333                          alpha22,
334                          &smin,
335                          &smax,
336                          &gammaL,
337                          &sigmaL,
338                          &gammaR,
339                          &sigmaR );
340 
341         *alpha11 = smax;
342         *alpha22 = smin;
343 
344         // Zero out the remaining diagonal.
345         *alpha12 = 0.0;
346 
347         // Store the rotations.
348         g1[0].real = gammaL;
349         g1[0].imag = sigmaL;
350         h1[0].real = gammaR;
351         h1[0].imag = sigmaR;
352 
353         // Update the local counters.
354         k_iter       += 1;
355         n_deflations += 1;
356 
357 #ifdef PRINTF
358         printf( "FLA_Bsvd_iteracc_v_opd_var1: Svv sval %22.15e in col %3d (n=%d) after %2d it  k_tot=%d/%d\n", buff_d[ 1*inc_d ], ijTL+1, 2, 1, k_iter, n_GH );
359         printf( "FLA_Bsvd_iteracc_v_opd_var1: Svv sval %22.15e in col %3d (n=%d) after %2d it  k_tot=%d/%d\n", buff_d[ 0*inc_d ], ijTL+0, 2, 0, k_iter, n_GH );
360 #endif
361     }
362 
363     *n_iter_perf = k_iter;
364     return n_deflations;
365 }
366