1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12
/*
   Implementation selectors used as the `type` argument of
   time_Svd_uv_components(). In this driver only FLA_ALG_REFERENCE and
   FLA_ALG_UNBLOCKED are passed. The other two are kept for parity with the
   timing routine; their meaning is suggested by their names only.
*/
#define FLA_ALG_REFERENCE  ( 0 )
#define FLA_ALG_UNBLOCKED  ( 1 )
#define FLA_ALG_UNB_OPT    ( 2 )
#define FLA_ALG_BLOCKED    ( 3 )
17
18
19 void time_Svd_uv_components(
20 int variant, int type, int n_repeats, int m, int n, int n_iter_max, int k_accum, int b_alg,
21 FLA_Obj A, FLA_Obj s, FLA_Obj U, FLA_Obj V,
22 double* dtime, double* diff1, double* diff2, double* gflops,
23 double* dtime_bred, double* gflops_bred,
24 double* dtime_bsvd, double* gflops_bsvd,
25 double* dtime_appq, double* gflops_appq,
26 double* dtime_qrfa, double* gflops_qrfa,
27 double* dtime_gemm, double* gflops_gemm, int* k_perf );
28
29
main(int argc,char * argv[])30 int main(int argc, char *argv[])
31 {
32 int
33 m_input, n_input,
34 m, n,
35 p_first, p_last, p_inc,
36 p,
37 min_m_n,
38 k_accum,
39 b_alg,
40 n_iter_max,
41 dist_type,
42 variant,
43 n_repeats,
44 i,
45 k_perf,
46 first_var = 1,
47 last_var = 2;
48
49 double
50 dtime,
51 gflops,
52 diff1,
53 diff2;
54 double
55 dtime_tred,
56 dtime_tevd,
57 dtime_appq,
58 dtime_qrfa,
59 dtime_gemm,
60 gflops_tred,
61 gflops_tevd,
62 gflops_appq,
63 gflops_qrfa,
64 gflops_gemm;
65
66 FLA_Datatype datatype, dt_real;
67
68 FLA_Obj
69 A, s, U, V, alpha, shift, nc;
70
71
72 FLA_Init();
73
74
75 fprintf( stdout, "%c number of repeats:", '%' );
76 scanf( "%d", &n_repeats );
77 fprintf( stdout, "%c %d\n", '%', n_repeats );
78
79 fprintf( stdout, "%c enter n_iter_max:", '%' );
80 scanf( "%d", &n_iter_max );
81 fprintf( stdout, "%c %d\n", '%', n_iter_max );
82
83 fprintf( stdout, "%c enter number of Givens rotations to accumulate:", '%' );
84 scanf( "%d", &k_accum );
85 fprintf( stdout, "%c %d\n", '%', k_accum );
86
87 fprintf( stdout, "%c enter blocking size:", '%' );
88 scanf( "%d", &b_alg );
89 fprintf( stdout, "%c %d\n", '%', b_alg );
90
91 fprintf( stdout, "%c enter problem size first, last, inc:", '%' );
92 scanf( "%d%d%d", &p_first, &p_last, &p_inc );
93 fprintf( stdout, "%c %d %d %d\n", '%', p_first, p_last, p_inc );
94
95 fprintf( stdout, "%c enter m n (-1 means bind to problem size): ", '%' );
96 scanf( "%d %d", &m_input, &n_input );
97 fprintf( stdout, "%c %d %d\n", '%', m_input, n_input );
98
99 fprintf( stdout, "%c enter distribution type (0=u, 1=i, 2=g, 3=l, 4=r, 5=c): ", '%' );
100 scanf( "%d", &dist_type );
101 fprintf( stdout, "%c %d\n", '%', dist_type );
102
103
104 fprintf( stdout, "\n" );
105
106
107
108 for ( p = p_first, i = 1; p <= p_last; p += p_inc, i += 1 )
109 {
110
111 m = m_input;
112 n = n_input;
113
114 if( m < 0 ) m = p / abs(m_input);
115 if( n < 0 ) n = p / abs(n_input);
116
117 min_m_n = min( m, n );
118
119 //datatype = FLA_FLOAT;
120 //datatype = FLA_DOUBLE;
121 //datatype = FLA_COMPLEX;
122 datatype = FLA_DOUBLE_COMPLEX;
123
124 FLA_Obj_create( datatype, m, n, 0, 0, &A );
125 FLA_Obj_create( datatype, m, m, 0, 0, &U );
126 FLA_Obj_create( datatype, n, n, 0, 0, &V );
127
128 dt_real = FLA_Obj_datatype_proj_to_real( A );
129
130 FLA_Obj_create( dt_real, min_m_n, 1, 0, 0, &s );
131 FLA_Obj_create( dt_real, 1, 1, 0, 0, &alpha );
132 FLA_Obj_create( dt_real, 1, 1, 0, 0, &shift );
133 FLA_Obj_create( FLA_INT, 1, 1, 0, 0, &nc );
134
135 FLA_Random_unitary_matrix( U );
136 FLA_Random_unitary_matrix( V );
137
138 if ( dist_type == 0 )
139 {
140 // Linear
141 *FLA_DOUBLE_PTR( shift ) = 0.0;
142 *FLA_DOUBLE_PTR( alpha ) = 1.0;
143 fprintf( stdout, "%c using linear dist.\n", '%' );
144 fprintf( stdout, "%c delta = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
145 FLA_Fill_with_linear_dist( shift, alpha, s );
146 }
147 else if ( dist_type == 1 )
148 {
149 // Inverse
150 *FLA_DOUBLE_PTR( alpha ) = 1.0;
151 fprintf( stdout, "%c using inverse dist.\n", '%' );
152 fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
153 FLA_Fill_with_inverse_dist( alpha, s );
154 }
155 else if ( dist_type == 2 )
156 {
157 // Geometric
158 *FLA_DOUBLE_PTR( alpha ) = 1.0 / (double)min_m_n;
159 fprintf( stdout, "%c using geometric dist.\n", '%' );
160 fprintf( stdout, "%c alpha = %10.4e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
161 FLA_Fill_with_geometric_dist( alpha, s );
162 }
163 else if ( dist_type == 3 )
164 {
165 // Logarithmic
166 *FLA_DOUBLE_PTR( alpha ) = 1.20;
167 fprintf( stdout, "%c using logarithmic dist.\n", '%' );
168 fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
169 FLA_Fill_with_logarithmic_dist( alpha, s );
170 }
171 else if ( dist_type == 4 )
172 {
173 // Random
174 *FLA_DOUBLE_PTR( shift ) = 0.0;
175 *FLA_DOUBLE_PTR( alpha ) = (double)min_m_n;
176 fprintf( stdout, "%c using random dist.\n", '%' );
177 fprintf( stdout, "%c shift = %13.8e\n", '%', *FLA_DOUBLE_PTR( shift ) );
178 fprintf( stdout, "%c alpha = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
179 FLA_Fill_with_random_dist( shift, alpha, s );
180 }
181 else if ( dist_type == 5 )
182 {
183 // Cluster
184 *FLA_INT_PTR( nc ) = 10;
185 *FLA_DOUBLE_PTR( alpha ) = 1.0e-9;
186 fprintf( stdout, "%c using cluster dist.\n", '%' );
187 fprintf( stdout, "%c num clusters = %d\n", '%', *FLA_INT_PTR( nc ) );
188 fprintf( stdout, "%c cluster width = %9.3e\n", '%', *FLA_DOUBLE_PTR( alpha ) );
189 FLA_Fill_with_cluster_dist( nc, alpha, s );
190 }
191
192 {
193 FLA_Obj UL, UR;
194 FLA_Obj VL, VR;
195
196 FLA_Part_1x2( U, &UL, &UR, min_m_n, FLA_LEFT );
197 FLA_Part_1x2( V, &VL, &VR, min_m_n, FLA_LEFT );
198
199 FLA_Apply_diag_matrix( FLA_RIGHT, FLA_NO_CONJUGATE, s, UL );
200 FLA_Gemm( FLA_NO_TRANSPOSE, FLA_CONJ_TRANSPOSE,
201 FLA_ONE, UL, VL, FLA_ZERO, A );
202 }
203
204 FLA_Set( FLA_ZERO, s );
205 FLA_Set( FLA_ZERO, U );
206 FLA_Set( FLA_ZERO, V );
207
208 fprintf( stdout, "%c total---------- reduction------ bi svd--------- form/app QUV--- QR fact-------- gemm-----------\n", '%' );
209 fprintf( stdout, "%c %4s gflops dtime resid |I-QQ'| gflops dtime gflops dtime gflops dtime gflops dtime gflops dtime niter\n", '%', "p" );
210
211 time_Svd_uv_components( -2, FLA_ALG_REFERENCE, n_repeats,
212 m, n, n_iter_max, k_accum, b_alg,
213 A, s, U, V, &dtime, &diff1, &diff2, &gflops,
214 &dtime_tred, &gflops_tred,
215 &dtime_tevd, &gflops_tevd,
216 &dtime_appq, &gflops_appq,
217 &dtime_qrfa, &gflops_qrfa,
218 &dtime_gemm, &gflops_gemm, &k_perf );
219 if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
220 dtime_tred = dtime_tevd = dtime_appq = dtime_qrfa = dtime_gemm = 0.0;
221
222 fprintf( stdout, "data_refq( %2d, 1:16 ) = [ %4d %6.3lf %9.2e %6.2le %6.2le %6.3lf %9.2e %7.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %5d ];\n",
223 i, p, gflops, dtime, diff1, diff2,
224 gflops_tred, dtime_tred,
225 gflops_tevd, dtime_tevd,
226 gflops_appq, dtime_appq,
227 gflops_qrfa, dtime_qrfa,
228 gflops_gemm, dtime_gemm, k_perf );
229 fflush( stdout );
230
231 time_Svd_uv_components( -3, FLA_ALG_REFERENCE, n_repeats,
232 m, n, n_iter_max, k_accum, b_alg,
233 A, s, U, V, &dtime, &diff1, &diff2, &gflops,
234 &dtime_tred, &gflops_tred,
235 &dtime_tevd, &gflops_tevd,
236 &dtime_appq, &gflops_appq,
237 &dtime_qrfa, &gflops_qrfa,
238 &dtime_gemm, &gflops_gemm, &k_perf );
239 if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
240 dtime_tred = dtime_tevd = dtime_appq = dtime_qrfa = dtime_gemm = 0.0;
241
242 fprintf( stdout, "data_refd( %2d, 1:16 ) = [ %4d %6.3lf %9.2e %6.2le %6.2le %6.3lf %9.2e %7.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %5d ];\n",
243 i, p, gflops, dtime, diff1, diff2,
244 gflops_tred, dtime_tred,
245 gflops_tevd, dtime_tevd,
246 gflops_appq, dtime_appq,
247 gflops_qrfa, dtime_qrfa,
248 gflops_gemm, dtime_gemm, k_perf );
249 fflush( stdout );
250
251 time_Svd_uv_components( 0, FLA_ALG_REFERENCE, n_repeats,
252 m, n, n_iter_max, k_accum, b_alg,
253 A, s, U, V, &dtime, &diff1, &diff2, &gflops,
254 &dtime_tred, &gflops_tred,
255 &dtime_tevd, &gflops_tevd,
256 &dtime_appq, &gflops_appq,
257 &dtime_qrfa, &gflops_qrfa,
258 &dtime_gemm, &gflops_gemm, &k_perf );
259 if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
260 dtime_tred = dtime_tevd = dtime_appq = dtime_qrfa = dtime_gemm = 0.0;
261
262 fprintf( stdout, "data_REFq( %2d, 1:16 ) = [ %4d %6.3lf %9.2e %6.2le %6.2le %6.3lf %9.2e %7.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %5d ];\n",
263 i, p, gflops, dtime, diff1, diff2,
264 gflops_tred, dtime_tred,
265 gflops_tevd, dtime_tevd,
266 gflops_appq, dtime_appq,
267 gflops_qrfa, dtime_qrfa,
268 gflops_gemm, dtime_gemm, k_perf );
269 fflush( stdout );
270
271 time_Svd_uv_components( -1, FLA_ALG_REFERENCE, n_repeats,
272 m, n, n_iter_max, k_accum, b_alg,
273 A, s, U, V, &dtime, &diff1, &diff2, &gflops,
274 &dtime_tred, &gflops_tred,
275 &dtime_tevd, &gflops_tevd,
276 &dtime_appq, &gflops_appq,
277 &dtime_qrfa, &gflops_qrfa,
278 &dtime_gemm, &gflops_gemm, &k_perf );
279 if ( dtime_tred == 1.0 ) gflops_tred = gflops_tevd = gflops_appq = gflops_qrfa = gflops_gemm =
280 dtime_tred = dtime_tevd = dtime_appq = dtime_qrfa = dtime_gemm = 0.0;
281
282 fprintf( stdout, "data_REFd( %2d, 1:16 ) = [ %4d %6.3lf %9.2e %6.2le %6.2le %6.3lf %9.2e %7.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %5d ];\n",
283 i, p, gflops, dtime, diff1, diff2,
284 gflops_tred, dtime_tred,
285 gflops_tevd, dtime_tevd,
286 gflops_appq, dtime_appq,
287 gflops_qrfa, dtime_qrfa,
288 gflops_gemm, dtime_gemm, k_perf );
289 fflush( stdout );
290
291
292 for ( variant = first_var; variant <= last_var; variant++ ){
293
294 fprintf( stdout, "data_var%d( %2d, 1:16 ) = [ %4d ", variant, i, p );
295 fflush( stdout );
296
297 time_Svd_uv_components( variant, FLA_ALG_UNBLOCKED, n_repeats,
298 m, n, n_iter_max, k_accum, b_alg,
299 A, s, U, V, &dtime, &diff1, &diff2, &gflops,
300 &dtime_tred, &gflops_tred,
301 &dtime_tevd, &gflops_tevd,
302 &dtime_appq, &gflops_appq,
303 &dtime_qrfa, &gflops_qrfa,
304 &dtime_gemm, &gflops_gemm, &k_perf );
305
306 fprintf( stdout, "%6.3lf %9.2e %6.2le %6.2le %6.3lf %9.2e %7.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %6.3lf %9.2e %5d ",
307 gflops, dtime, diff1, diff2,
308 gflops_tred, dtime_tred,
309 gflops_tevd, dtime_tevd,
310 gflops_appq, dtime_appq,
311 gflops_qrfa, dtime_qrfa,
312 gflops_gemm, dtime_gemm, k_perf );
313
314 fprintf( stdout, "];\n" );
315 fflush( stdout );
316 }
317
318 fprintf( stdout, "\n" );
319
320 FLA_Obj_free( &A );
321 FLA_Obj_free( &U );
322 FLA_Obj_free( &V );
323 FLA_Obj_free( &s );
324 FLA_Obj_free( &alpha );
325 FLA_Obj_free( &shift );
326 FLA_Obj_free( &nc );
327 }
328
329 FLA_Finalize( );
330
331 return 0;
332 }
333
334