1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Svv_2x2(FLA_Obj alpha11,FLA_Obj alpha12,FLA_Obj alpha22,FLA_Obj sigma1,FLA_Obj sigma2,FLA_Obj gammaL,FLA_Obj sigmaL,FLA_Obj gammaR,FLA_Obj sigmaR)13 FLA_Error FLA_Svv_2x2( FLA_Obj alpha11, FLA_Obj alpha12, FLA_Obj alpha22,
14                        FLA_Obj sigma1, FLA_Obj sigma2,
15                        FLA_Obj gammaL, FLA_Obj sigmaL,
16                        FLA_Obj gammaR, FLA_Obj sigmaR )
17 /*
18   Compute the singular value decomposition of a 2x2 triangular matrix A
19   such that
20 
21     / alpha11 alpha12 \
22     \    0    alpha22 /
23 
24   is equal to
25 
26     / gammaL -sigmaL \ / sigma1    0    \ / gammaR -sigmaR \'
27     \ sigmaL  gammaL / \    0    sigma2 / \ sigmaR  gammaR /
28 
29   Upon completion, sigma1 and sigma2 are overwritten with the
30   singular values of smaller and larger absolute values, respectively,
31   while gammaL, sigmaL, gammaR, and sigmaR determine the corresponding
32   left and right singular vector elements.
33 
34   This routine is a nearly-verbatim translation of slasv2() and dlasv2()
35   from the netlib distribution of LAPACK.
36 
37   -FGVZ
38 */
39 {
40     FLA_Datatype datatype;
41 
42     datatype = FLA_Obj_datatype( alpha11 );
43 
44     switch ( datatype )
45     {
46     case FLA_FLOAT:
47     {
48         float*  buff_alpha11 = FLA_FLOAT_PTR( alpha11 );
49         float*  buff_alpha12 = FLA_FLOAT_PTR( alpha12 );
50         float*  buff_alpha22 = FLA_FLOAT_PTR( alpha22 );
51         float*  buff_sigma1  = FLA_FLOAT_PTR( sigma1 );
52         float*  buff_sigma2  = FLA_FLOAT_PTR( sigma2 );
53         float*  buff_gammaL  = FLA_FLOAT_PTR( gammaL );
54         float*  buff_sigmaL  = FLA_FLOAT_PTR( sigmaL );
55         float*  buff_gammaR  = FLA_FLOAT_PTR( gammaR );
56         float*  buff_sigmaR  = FLA_FLOAT_PTR( sigmaR );
57 
58         FLA_Svv_2x2_ops( buff_alpha11,
59                          buff_alpha12,
60                          buff_alpha22,
61                          buff_sigma1,
62                          buff_sigma2,
63                          buff_gammaL,
64                          buff_sigmaL,
65                          buff_gammaR,
66                          buff_sigmaR );
67 
68         break;
69     }
70 
71     case FLA_DOUBLE:
72     {
73         double* buff_alpha11 = FLA_DOUBLE_PTR( alpha11 );
74         double* buff_alpha12 = FLA_DOUBLE_PTR( alpha12 );
75         double* buff_alpha22 = FLA_DOUBLE_PTR( alpha22 );
76         double* buff_sigma1  = FLA_DOUBLE_PTR( sigma1 );
77         double* buff_sigma2  = FLA_DOUBLE_PTR( sigma2 );
78         double* buff_gammaL  = FLA_DOUBLE_PTR( gammaL );
79         double* buff_sigmaL  = FLA_DOUBLE_PTR( sigmaL );
80         double* buff_gammaR  = FLA_DOUBLE_PTR( gammaR );
81         double* buff_sigmaR  = FLA_DOUBLE_PTR( sigmaR );
82 
83         FLA_Svv_2x2_opd( buff_alpha11,
84                          buff_alpha12,
85                          buff_alpha22,
86                          buff_sigma1,
87                          buff_sigma2,
88                          buff_gammaL,
89                          buff_sigmaL,
90                          buff_gammaR,
91                          buff_sigmaR );
92 
93         break;
94     }
95     }
96 
97     return FLA_SUCCESS;
98 }
99 
100 
101 
FLA_Svv_2x2_ops(float * alpha11,float * alpha12,float * alpha22,float * sigma1,float * sigma2,float * gammaL,float * sigmaL,float * gammaR,float * sigmaR)102 FLA_Error FLA_Svv_2x2_ops( float*    alpha11,
103                            float*    alpha12,
104                            float*    alpha22,
105                            float*    sigma1,
106                            float*    sigma2,
107                            float*    gammaL,
108                            float*    sigmaL,
109                            float*    gammaR,
110                            float*    sigmaR )
111 {
112     float  zero = 0.0F;
113     float  half = 0.5F;
114     float  one  = 1.0F;
115     float  two  = 2.0F;
116     float  four = 4.0F;
117 
118     float  eps;
119 
120     float  f, g, h;
121     float  clt, crt, slt, srt;
122     float  a, d, fa, ft, ga, gt, ha, ht, l;
123     float  m, mm, r, s, t, temp, tsign, tt;
124     float  ssmin, ssmax;
125     float  csl, snl;
126     float  csr, snr;
127 
128     int    gasmal, swap;
129     int    pmax;
130 
131     f = *alpha11;
132     g = *alpha12;
133     h = *alpha22;
134 
135     eps = FLA_Mach_params_ops( FLA_MACH_EPS );
136 
137     ft = f;
138     fa = fabsf( f );
139     ht = h;
140     ha = fabsf( h );
141 
142     // pmax points to the maximum absolute element of matrix.
143     //   pmax = 1 if f largest in absolute values.
144     //   pmax = 2 if g largest in absolute values.
145     //   pmax = 3 if h largest in absolute values.
146 
147     pmax = 1;
148 
149     swap = ( ha > fa );
150     if ( swap )
151     {
152         pmax = 3;
153 
154         temp = ft;
155         ft = ht;
156         ht = temp;
157 
158         temp = fa;
159         fa = ha;
160         ha = temp;
161     }
162 
163     gt = g;
164     ga = fabsf( g );
165 
166     if ( ga == zero )
167     {
168         // Diagonal matrix case.
169 
170         ssmin = ha;
171         ssmax = fa;
172         clt   = one;
173         slt   = zero;
174         crt   = one;
175         srt   = zero;
176     }
177     else
178     {
179         gasmal = TRUE;
180 
181         if ( ga > fa )
182         {
183             pmax = 2;
184 
185             if ( ( fa / ga ) < eps )
186             {
187                 // Case of very large ga.
188 
189                 gasmal = FALSE;
190 
191                 ssmax  = ga;
192 
193                 if ( ha > one ) ssmin = fa / ( ga / ha );
194                 else            ssmin = ( fa / ga ) * ha;
195 
196                 clt = one;
197                 slt = ht / gt;
198                 crt = ft / gt;
199                 srt = one;
200             }
201         }
202 
203         if ( gasmal )
204         {
205             // Normal case.
206 
207             d = fa - ha;
208 
209             if ( d == fa ) l = one;
210             else           l = d / fa;
211 
212             m = gt / ft;
213 
214             t = two - l;
215 
216             mm = m * m;
217             tt = t * t;
218             s = sqrtf( tt + mm );
219 
220             if ( l == zero ) r = fabsf( m );
221             else             r = sqrtf( l * l + mm );
222 
223             a = half * ( s + r );
224 
225             ssmin = ha / a;
226             ssmax = fa * a;
227 
228             if ( mm == zero )
229             {
230                 // Here, m is tiny.
231 
232                 if ( l == zero ) t = signof( two, ft ) * signof( one, gt );
233                 else             t = gt / signof( d, ft ) + m / t;
234             }
235             else
236             {
237                 t = ( m / ( s + t ) + m / ( r + l ) ) * ( one + a );
238             }
239 
240             l = sqrtf( t*t + four );
241             crt = two / l;
242             srt = t / l;
243             clt = ( crt + srt * m ) / a;
244             slt = ( ht / ft ) * srt / a;
245         }
246     }
247 
248     if ( swap )
249     {
250         csl = srt;
251         snl = crt;
252         csr = slt;
253         snr = clt;
254     }
255     else
256     {
257         csl = clt;
258         snl = slt;
259         csr = crt;
260         snr = srt;
261     }
262 
263 
264     // Correct the signs of ssmax and ssmin.
265 
266     if      ( pmax == 1 )
267         tsign = signof( one, csr ) * signof( one, csl ) * signof( one, f );
268     else if ( pmax == 2 )
269         tsign = signof( one, snr ) * signof( one, csl ) * signof( one, g );
270     else // if ( pmax == 3 )
271         tsign = signof( one, snr ) * signof( one, snl ) * signof( one, h );
272 
273     ssmax = signof( ssmax, tsign );
274     ssmin = signof( ssmin, tsign * signof( one, f ) * signof( one, h ) );
275 
276     // Save the output values.
277 
278     *sigma1 = ssmin;
279     *sigma2 = ssmax;
280     *gammaL = csl;
281     *sigmaL = snl;
282     *gammaR = csr;
283     *sigmaR = snr;
284 
285     return FLA_SUCCESS;
286 }
287 
288 
289 
FLA_Svv_2x2_opd(double * alpha11,double * alpha12,double * alpha22,double * sigma1,double * sigma2,double * gammaL,double * sigmaL,double * gammaR,double * sigmaR)290 FLA_Error FLA_Svv_2x2_opd( double*   alpha11,
291                            double*   alpha12,
292                            double*   alpha22,
293                            double*   sigma1,
294                            double*   sigma2,
295                            double*   gammaL,
296                            double*   sigmaL,
297                            double*   gammaR,
298                            double*   sigmaR )
299 {
300     double zero = 0.0;
301     double half = 0.5;
302     double one  = 1.0;
303     double two  = 2.0;
304     double four = 4.0;
305 
306     double eps;
307 
308     double f, g, h;
309     double clt, crt, slt, srt;
310     double a, d, fa, ft, ga, gt, ha, ht, l;
311     double m, mm, r, s, t, temp, tsign, tt;
312     double ssmin, ssmax;
313     double csl, snl;
314     double csr, snr;
315 
316     int    gasmal, swap;
317     int    pmax;
318 
319     f = *alpha11;
320     g = *alpha12;
321     h = *alpha22;
322 
323     eps = FLA_Mach_params_opd( FLA_MACH_EPS );
324 
325     ft = f;
326     fa = fabs( f );
327     ht = h;
328     ha = fabs( h );
329 
330     // pmax points to the maximum absolute element of matrix.
331     //   pmax = 1 if f largest in absolute values.
332     //   pmax = 2 if g largest in absolute values.
333     //   pmax = 3 if h largest in absolute values.
334 
335     pmax = 1;
336 
337     swap = ( ha > fa );
338     if ( swap )
339     {
340         pmax = 3;
341 
342         temp = ft;
343         ft = ht;
344         ht = temp;
345 
346         temp = fa;
347         fa = ha;
348         ha = temp;
349     }
350 
351     gt = g;
352     ga = fabs( g );
353 
354     if ( ga == zero )
355     {
356         // Diagonal matrix case.
357 
358         ssmin = ha;
359         ssmax = fa;
360         clt   = one;
361         slt   = zero;
362         crt   = one;
363         srt   = zero;
364     }
365     else
366     {
367         gasmal = TRUE;
368 
369         if ( ga > fa )
370         {
371             pmax = 2;
372 
373             if ( ( fa / ga ) < eps )
374             {
375                 // Case of very large ga.
376 
377                 gasmal = FALSE;
378 
379                 ssmax  = ga;
380 
381                 if ( ha > one ) ssmin = fa / ( ga / ha );
382                 else            ssmin = ( fa / ga ) * ha;
383 
384                 clt = one;
385                 slt = ht / gt;
386                 crt = ft / gt;
387                 srt = one;
388             }
389         }
390 
391         if ( gasmal )
392         {
393             // Normal case.
394 
395             d = fa - ha;
396 
397             if ( d == fa ) l = one;
398             else           l = d / fa;
399 
400             m = gt / ft;
401 
402             t = two - l;
403 
404             mm = m * m;
405             tt = t * t;
406             s = sqrt( tt + mm );
407 
408             if ( l == zero ) r = fabs( m );
409             else             r = sqrt( l * l + mm );
410 
411             a = half * ( s + r );
412 
413             ssmin = ha / a;
414             ssmax = fa * a;
415 
416             if ( mm == zero )
417             {
418                 // Here, m is tiny.
419 
420                 if ( l == zero ) t = signof( two, ft ) * signof( one, gt );
421                 else             t = gt / signof( d, ft ) + m / t;
422             }
423             else
424             {
425                 t = ( m / ( s + t ) + m / ( r + l ) ) * ( one + a );
426             }
427 
428             l = sqrt( t*t + four );
429             crt = two / l;
430             srt = t / l;
431             clt = ( crt + srt * m ) / a;
432             slt = ( ht / ft ) * srt / a;
433         }
434     }
435 
436     if ( swap )
437     {
438         csl = srt;
439         snl = crt;
440         csr = slt;
441         snr = clt;
442     }
443     else
444     {
445         csl = clt;
446         snl = slt;
447         csr = crt;
448         snr = srt;
449     }
450 
451 
452     // Correct the signs of ssmax and ssmin.
453 
454     if      ( pmax == 1 )
455         tsign = signof( one, csr ) * signof( one, csl ) * signof( one, f );
456     else if ( pmax == 2 )
457         tsign = signof( one, snr ) * signof( one, csl ) * signof( one, g );
458     else // if ( pmax == 3 )
459         tsign = signof( one, snr ) * signof( one, snl ) * signof( one, h );
460 
461     ssmax = signof( ssmax, tsign );
462     ssmin = signof( ssmin, tsign * signof( one, f ) * signof( one, h ) );
463 
464     // Save the output values.
465 
466     *sigma1 = ssmin;
467     *sigma2 = ssmax;
468     *gammaL = csl;
469     *sigmaL = snl;
470     *gammaR = csr;
471     *sigmaR = snr;
472 
473     return FLA_SUCCESS;
474 }
475 
476