1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
/* Selector values passed as the `type` argument of time_Svd_uv_components().
   In this file main() passes FLA_ALG_REFERENCE for the reference runs
   (variants -2, -3, 0 and -1) and FLA_ALG_UNBLOCKED for the numbered
   variants; FLA_ALG_UNB_OPT and FLA_ALG_BLOCKED are not used here.
   NOTE(review): how each value is interpreted is decided by
   time_Svd_uv_components(), whose definition is not in this file. */
#define FLA_ALG_REFERENCE     0
#define FLA_ALG_UNBLOCKED     1
#define FLA_ALG_UNB_OPT       2
#define FLA_ALG_BLOCKED       3
17 
18 
19 void time_Svd_uv_components(
20                int variant, int type, int n_repeats, int m, int n, int n_iter_max, int k_accum, int b_alg,
21                FLA_Obj A, FLA_Obj s, FLA_Obj U, FLA_Obj V,
22                double* dtime, double* diff1, double* diff2, double* gflops,
23                double* dtime_bred, double* gflops_bred,
24                double* dtime_bsvd, double* gflops_bsvd,
25                double* dtime_appq, double* gflops_appq,
26                double* dtime_qrfa, double* gflops_qrfa,
27                double* dtime_gemm, double* gflops_gemm, int* k_perf );
28 
29 
main(int argc,char * argv[])30 int main(int argc, char *argv[])
31 {
32   int
33     m_input, n_input,
34     m, n,
35     p_first, p_last, p_inc,
36     p,
37     min_m_n,
38     k_accum,
39     b_alg,
40     n_iter_max,
41     dist_type,
42     variant,
43     n_repeats,
44     i,
45     k_perf,
46     first_var = 1,
47     last_var  = 2;
48 
49   double
50     dtime,
51     gflops,
52     diff1,
53     diff2;
54   double
55     dtime_tred,
56     dtime_tevd,
57     dtime_appq,
58     dtime_qrfa,
59     dtime_gemm,
60     gflops_tred,
61     gflops_tevd,
62     gflops_appq,
63     gflops_qrfa,
64     gflops_gemm;
65 
66   FLA_Datatype datatype, dt_real;
67 
68   FLA_Obj
69     A, s, U, V, alpha, shift, nc;
70 
71 
72   FLA_Init();
73 
74 
75   fprintf( stdout, "%c number of repeats:", '%' );
76   scanf( "%d", &n_repeats );
77   fprintf( stdout, "%c %d\n", '%', n_repeats );
78 
79   fprintf( stdout, "%c enter n_iter_max:", '%' );
80   scanf( "%d", &n_iter_max );
81   fprintf( stdout, "%c %d\n", '%', n_iter_max );
82 
83   fprintf( stdout, "%c enter number of Givens rotations to accumulate:", '%' );
84   scanf( "%d", &k_accum );
85   fprintf( stdout, "%c %d\n", '%', k_accum );
86 
87   fprintf( stdout, "%c enter blocking size:", '%' );
88   scanf( "%d", &b_alg );
89   fprintf( stdout, "%c %d\n", '%', b_alg );
90 
91   fprintf( stdout, "%c enter problem size first, last, inc:", '%' );
92   scanf( "%d%d%d", &p_first, &p_last, &p_inc );
93   fprintf( stdout, "%c %d %d %d\n", '%', p_first, p_last, p_inc );
94 
95   fprintf( stdout, "%c enter m n (-1 means bind to problem size): ", '%' );
96   scanf( "%d %d", &m_input, &n_input );
97   fprintf( stdout, "%c %d %d\n", '%', m_input, n_input );
98 
99   fprintf( stdout, "%c enter distribution type (0=u, 1=i, 2=g, 3=l, 4=r, 5=c): ", '%' );
100   scanf( "%d", &dist_type );
101   fprintf( stdout, "%c %d\n", '%', dist_type );
102 
103 
104   fprintf( stdout, "\n" );
105 
106 
107 
108   for ( p = p_first, i = 1; p <= p_last; p += p_inc, i += 1 )
109   {
110 
111     m = m_input;
112     n = n_input;
113 
114     if( m < 0 ) m = p / abs(m_input);
115     if( n < 0 ) n = p / abs(n_input);
116 
117     min_m_n = min( m, n );
118 
119     //datatype = FLA_FLOAT;
120     //datatype = FLA_DOUBLE;
121     //datatype = FLA_COMPLEX;
122     datatype = FLA_DOUBLE_COMPLEX;
123 
124     FLA_Obj_create( datatype, m,       n, 0, 0, &A );
125     FLA_Obj_create( datatype, m,       m, 0, 0, &U );
126     FLA_Obj_create( datatype, n,       n, 0, 0, &V );
127 
128     dt_real = FLA_Obj_datatype_proj_to_real( A );
129 
130     FLA_Obj_create( dt_real,  min_m_n, 1, 0, 0, &s );
131     FLA_Obj_create( dt_real,  1,       1, 0, 0, &alpha );
132     FLA_Obj_create( dt_real,  1,       1, 0, 0, &shift );
133     FLA_Obj_create( FLA_INT,  1,       1, 0, 0, &nc );
134 
135     FLA_Random_unitary_matrix( U );
136     FLA_Random_unitary_matrix( V );
137 
138     if ( dist_type == 0 )
139     {
140       // Linear
141       *FLA_DOUBLE_PTR( shift ) = 0.0;
142       *FLA_DOUBLE_PTR( alpha ) = 1.0;
143       fprintf( stdout, "%c using linear dist.\n", '%' );
144       fprintf( stdout, "%c delta = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
145       FLA_Fill_with_linear_dist( shift, alpha, s );
146     }
147     else if ( dist_type == 1 )
148     {
149       // Inverse
150       *FLA_DOUBLE_PTR( alpha ) = 1.0;
151       fprintf( stdout, "%c using inverse dist.\n", '%' );
152       fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
153       FLA_Fill_with_inverse_dist( alpha, s );
154     }
155     else if ( dist_type == 2 )
156     {
157       // Geometric
158       *FLA_DOUBLE_PTR( alpha ) = 1.0 / (double)min_m_n;
159       fprintf( stdout, "%c using geometric dist.\n", '%' );
160       fprintf( stdout, "%c alpha = %10.4e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
161       FLA_Fill_with_geometric_dist( alpha, s );
162     }
163     else if ( dist_type == 3 )
164     {
165       // Logarithmic
166       *FLA_DOUBLE_PTR( alpha ) = 1.20;
167       fprintf( stdout, "%c using logarithmic dist.\n", '%' );
168       fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
169       FLA_Fill_with_logarithmic_dist( alpha, s );
170     }
171     else if ( dist_type == 4 )
172     {
173       // Random
174       *FLA_DOUBLE_PTR( shift ) = 0.0;
175       *FLA_DOUBLE_PTR( alpha ) = (double)min_m_n;
176       fprintf( stdout, "%c using random dist.\n", '%' );
177       fprintf( stdout, "%c shift = %13.8e\n", '%', *FLA_DOUBLE_PTR( shift ) );
178       fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
179       FLA_Fill_with_random_dist( shift, alpha, s );
180     }
181     else if ( dist_type == 5 )
182     {
183       // Cluster
184       *FLA_INT_PTR( nc )       = 10;
185       *FLA_DOUBLE_PTR( alpha ) = 1.0e-9;
186       fprintf( stdout, "%c using cluster dist.\n", '%' );
187       fprintf( stdout, "%c num clusters  = %d\n", '%', *FLA_INT_PTR( nc ) );
188       fprintf( stdout, "%c cluster width = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
189       FLA_Fill_with_cluster_dist( nc, alpha, s );
190     }
191 
192     {
193       FLA_Obj UL, UR;
194       FLA_Obj VL, VR;
195 
196       FLA_Part_1x2( U,   &UL, &UR,   min_m_n, FLA_LEFT );
197       FLA_Part_1x2( V,   &VL, &VR,   min_m_n, FLA_LEFT );
198 
199       FLA_Apply_diag_matrix( FLA_RIGHT, FLA_NO_CONJUGATE, s, UL );
200       FLA_Gemm( FLA_NO_TRANSPOSE, FLA_CONJ_TRANSPOSE,
201                 FLA_ONE, UL, VL, FLA_ZERO, A );
202     }
203 
204     FLA_Set( FLA_ZERO, s );
205     FLA_Set( FLA_ZERO, U );
206     FLA_Set( FLA_ZERO, V );
207 
208     fprintf( stdout, "%c                               total----------                       reduction------    bi svd---------   form/app QUV---   QR fact--------   gemm-----------\n", '%' );
209     fprintf( stdout, "%c                         %4s  gflops dtime      resid    |I-QQ'|    gflops dtime       gflops dtime      gflops dtime      gflops dtime      gflops dtime     niter\n", '%', "p" );
210 
211     time_Svd_uv_components( -2, FLA_ALG_REFERENCE, n_repeats,
212                             m, n, n_iter_max, k_accum, b_alg,
213                             A, s, U, V, &dtime, &diff1, &diff2, &gflops,
214                             &dtime_tred, &gflops_tred,
215                             &dtime_tevd, &gflops_tevd,
216                             &dtime_appq, &gflops_appq,
217                             &dtime_qrfa, &gflops_qrfa,
218                             &dtime_gemm, &gflops_gemm, &k_perf );
219     if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
220                               dtime_tred =  dtime_tevd =  dtime_appq =  dtime_qrfa =  dtime_gemm = 0.0;
221 
222     fprintf( stdout, "data_refq( %2d, 1:16 ) = [ %4d %6.3lf %9.2e   %6.2le %6.2le  %6.3lf %9.2e  %7.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e %5d ];\n",
223              i, p, gflops, dtime, diff1, diff2,
224              gflops_tred, dtime_tred,
225              gflops_tevd, dtime_tevd,
226              gflops_appq, dtime_appq,
227              gflops_qrfa, dtime_qrfa,
228              gflops_gemm, dtime_gemm, k_perf );
229     fflush( stdout );
230 
231     time_Svd_uv_components( -3, FLA_ALG_REFERENCE, n_repeats,
232                             m, n, n_iter_max, k_accum, b_alg,
233                             A, s, U, V, &dtime, &diff1, &diff2, &gflops,
234                             &dtime_tred, &gflops_tred,
235                             &dtime_tevd, &gflops_tevd,
236                             &dtime_appq, &gflops_appq,
237                             &dtime_qrfa, &gflops_qrfa,
238                             &dtime_gemm, &gflops_gemm, &k_perf );
239     if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
240                               dtime_tred =  dtime_tevd =  dtime_appq =  dtime_qrfa =  dtime_gemm = 0.0;
241 
242     fprintf( stdout, "data_refd( %2d, 1:16 ) = [ %4d %6.3lf %9.2e   %6.2le %6.2le  %6.3lf %9.2e  %7.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e %5d ];\n",
243              i, p, gflops, dtime, diff1, diff2,
244              gflops_tred, dtime_tred,
245              gflops_tevd, dtime_tevd,
246              gflops_appq, dtime_appq,
247              gflops_qrfa, dtime_qrfa,
248              gflops_gemm, dtime_gemm, k_perf );
249     fflush( stdout );
250 
251     time_Svd_uv_components( 0, FLA_ALG_REFERENCE, n_repeats,
252                             m, n, n_iter_max, k_accum, b_alg,
253                             A, s, U, V, &dtime, &diff1, &diff2, &gflops,
254                             &dtime_tred, &gflops_tred,
255                             &dtime_tevd, &gflops_tevd,
256                             &dtime_appq, &gflops_appq,
257                             &dtime_qrfa, &gflops_qrfa,
258                             &dtime_gemm, &gflops_gemm, &k_perf );
259     if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
260                               dtime_tred =  dtime_tevd =  dtime_appq =  dtime_qrfa =  dtime_gemm = 0.0;
261 
262     fprintf( stdout, "data_REFq( %2d, 1:16 ) = [ %4d %6.3lf %9.2e   %6.2le %6.2le  %6.3lf %9.2e  %7.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e %5d ];\n",
263              i, p, gflops, dtime, diff1, diff2,
264              gflops_tred, dtime_tred,
265              gflops_tevd, dtime_tevd,
266              gflops_appq, dtime_appq,
267              gflops_qrfa, dtime_qrfa,
268              gflops_gemm, dtime_gemm, k_perf );
269     fflush( stdout );
270 
271     time_Svd_uv_components( -1, FLA_ALG_REFERENCE, n_repeats,
272                             m, n, n_iter_max, k_accum, b_alg,
273                             A, s, U, V, &dtime, &diff1, &diff2, &gflops,
274                             &dtime_tred, &gflops_tred,
275                             &dtime_tevd, &gflops_tevd,
276                             &dtime_appq, &gflops_appq,
277                             &dtime_qrfa, &gflops_qrfa,
278                             &dtime_gemm, &gflops_gemm, &k_perf );
279     if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
280                               dtime_tred =  dtime_tevd =  dtime_appq =  dtime_qrfa =  dtime_gemm = 0.0;
281 
282     fprintf( stdout, "data_REFd( %2d, 1:16 ) = [ %4d %6.3lf %9.2e   %6.2le %6.2le  %6.3lf %9.2e  %7.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e %5d ];\n",
283              i, p, gflops, dtime, diff1, diff2,
284              gflops_tred, dtime_tred,
285              gflops_tevd, dtime_tevd,
286              gflops_appq, dtime_appq,
287              gflops_qrfa, dtime_qrfa,
288              gflops_gemm, dtime_gemm, k_perf );
289     fflush( stdout );
290 
291 
292     for ( variant = first_var; variant <= last_var; variant++ ){
293 
294       fprintf( stdout, "data_var%d( %2d, 1:16 ) = [ %4d ", variant, i, p );
295       fflush( stdout );
296 
297       time_Svd_uv_components( variant, FLA_ALG_UNBLOCKED, n_repeats,
298                               m, n, n_iter_max, k_accum, b_alg,
299                               A, s, U, V, &dtime, &diff1, &diff2, &gflops,
300                               &dtime_tred, &gflops_tred,
301                               &dtime_tevd, &gflops_tevd,
302                               &dtime_appq, &gflops_appq,
303                               &dtime_qrfa, &gflops_qrfa,
304                               &dtime_gemm, &gflops_gemm, &k_perf );
305 
306       fprintf( stdout, "%6.3lf %9.2e   %6.2le %6.2le  %6.3lf %9.2e  %7.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e  %6.3lf %9.2e %5d ",
307                gflops, dtime, diff1, diff2,
308                gflops_tred, dtime_tred,
309                gflops_tevd, dtime_tevd,
310                gflops_appq, dtime_appq,
311                gflops_qrfa, dtime_qrfa,
312                gflops_gemm, dtime_gemm, k_perf );
313 
314       fprintf( stdout, "];\n" );
315       fflush( stdout );
316     }
317 
318     fprintf( stdout, "\n" );
319 
320     FLA_Obj_free( &A );
321     FLA_Obj_free( &U );
322     FLA_Obj_free( &V );
323     FLA_Obj_free( &s );
324     FLA_Obj_free( &alpha );
325     FLA_Obj_free( &shift );
326     FLA_Obj_free( &nc );
327   }
328 
329   FLA_Finalize( );
330 
331   return 0;
332 }
333 
334