1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12
/* Algorithm-type selectors passed as the `type` argument of time_Svd_uv.
   This driver only uses FLA_ALG_REFERENCE (variants 0 and -1) and
   FLA_ALG_UNBLOCKED (variants 1..n_variants); the other two are unused here. */
#define FLA_ALG_REFERENCE 0
#define FLA_ALG_UNBLOCKED 1
#define FLA_ALG_UNB_OPT 2
#define FLA_ALG_BLOCKED 3
17
18
/*
   Runs one SVD implementation (selected by variant/type) on A and returns
   its measurements through the trailing pointers (dtime, diff1, diff2,
   gflops, k_perf), which main prints. Defined in another translation unit.

   NOTE(review): the parameter names below list ( A, U, V, s ), but every
   call in main passes ( A, s, U, V ). All four are FLA_Obj, so a mismatch
   compiles silently; confirm the order against the definition of
   time_Svd_uv.
*/
void time_Svd_uv(
int variant, int type, int n_repeats, int m, int n, int n_iter_max, int k_accum, int b_alg,
FLA_Obj A, FLA_Obj U, FLA_Obj V, FLA_Obj s,
double *dtime, double *diff1, double* diff2, double *gflops, int* k_perf );
23
24
main(int argc,char * argv[])25 int main(int argc, char *argv[])
26 {
27 int
28 m_input, n_input,
29 m, n,
30 p_first, p_last, p_inc,
31 p,
32 k_accum,
33 b_alg,
34 n_iter_max,
35 min_m_n,
36 variant,
37 n_repeats,
38 i,
39 k_perf,
40 n_variants = 2;
41
42 char *colors = "brkgmcbrkg";
43 char *ticks = "o+*xso+*xs";
44 char m_dim_desc[14];
45 char m_dim_tag[10];
46 char n_dim_desc[14];
47 char n_dim_tag[10];
48
49 double max_gflops=6.0;
50
51 double
52 dtime,
53 gflops,
54 diff1,
55 diff2;
56
57 FLA_Datatype datatype, dt_real;
58
59 FLA_Obj
60 A, s, U, V, alpha;
61
62
63 FLA_Init();
64
65
66 fprintf( stdout, "%c number of repeats:", '%' );
67 scanf( "%d", &n_repeats );
68 fprintf( stdout, "%c %d\n", '%', n_repeats );
69
70 fprintf( stdout, "%c enter n_iter_max:", '%' );
71 scanf( "%d", &n_iter_max );
72 fprintf( stdout, "%c %d\n", '%', n_iter_max );
73
74 fprintf( stdout, "%c enter number of Givens rotations to accumulate:", '%' );
75 scanf( "%d", &k_accum );
76 fprintf( stdout, "%c %d\n", '%', k_accum );
77
78 fprintf( stdout, "%c enter blocking size:", '%' );
79 scanf( "%d", &b_alg );
80 fprintf( stdout, "%c %d\n", '%', b_alg );
81
82 fprintf( stdout, "%c enter problem size first, last, inc:", '%' );
83 scanf( "%d%d%d", &p_first, &p_last, &p_inc );
84 fprintf( stdout, "%c %d %d %d\n", '%', p_first, p_last, p_inc );
85
86 fprintf( stdout, "%c enter m n (-1 means bind to problem size): ", '%' );
87 scanf( "%d %d", &m_input, &n_input );
88 fprintf( stdout, "%c %d %d\n", '%', m_input, n_input );
89
90
91 fprintf( stdout, "\n" );
92
93
94 if ( m_input > 0 ) {
95 sprintf( m_dim_desc, "m = %d", m_input );
96 sprintf( m_dim_tag, "m%dc", m_input);
97 }
98 else if( m_input < -1 ) {
99 sprintf( m_dim_desc, "m = p/%d", -m_input );
100 sprintf( m_dim_tag, "m%dp", -m_input );
101 }
102 else if( m_input == -1 ) {
103 sprintf( m_dim_desc, "m = p" );
104 sprintf( m_dim_tag, "m%dp", 1 );
105 }
106 if ( n_input > 0 ) {
107 sprintf( n_dim_desc, "n = %d", n_input );
108 sprintf( n_dim_tag, "n%dc", n_input);
109 }
110 else if( n_input < -1 ) {
111 sprintf( n_dim_desc, "n = p/%d", -n_input );
112 sprintf( n_dim_tag, "n%dp", -n_input );
113 }
114 else if( n_input == -1 ) {
115 sprintf( n_dim_desc, "n = p" );
116 sprintf( n_dim_tag, "n%dp", 1 );
117 }
118
119 for ( p = p_first, i = 1; p <= p_last; p += p_inc, i += 1 )
120 {
121
122 m = m_input;
123 n = n_input;
124
125 if( m < 0 ) m = p / f2c_abs(m_input);
126 if( n < 0 ) n = p / f2c_abs(n_input);
127
128 min_m_n = min( m, n );
129
130 //datatype = FLA_FLOAT;
131 datatype = FLA_DOUBLE;
132 //datatype = FLA_COMPLEX;
133 //datatype = FLA_DOUBLE_COMPLEX;
134
135 FLA_Obj_create( datatype, m, n, 0, 0, &A );
136 FLA_Obj_create( datatype, m, m, 0, 0, &U );
137 FLA_Obj_create( datatype, n, n, 0, 0, &V );
138
139 dt_real = FLA_Obj_datatype_proj_to_real( A );
140
141 FLA_Obj_create( dt_real, min_m_n, 1, 0, 0, &s );
142 FLA_Obj_create( dt_real, 1, 1, 0, 0, &alpha );
143
144 FLA_Random_unitary_matrix( U );
145 FLA_Random_unitary_matrix( V );
146
147 FLA_Fill_with_linear_dist( FLA_ZERO, FLA_ONE, s );
148
149 //FLA_Fill_with_inverse_dist( FLA_ONE, s );
150
151 //*FLA_DOUBLE_PTR( alpha ) = 1.0 / sqrt( (double) min_m_n );
152 //FLA_Fill_with_geometric_dist( alpha, s );
153
154 {
155 FLA_Obj UL, UR;
156 FLA_Obj VL, VR;
157
158 FLA_Part_1x2( U, &UL, &UR, min_m_n, FLA_LEFT );
159 FLA_Part_1x2( V, &VL, &VR, min_m_n, FLA_LEFT );
160
161 FLA_Apply_diag_matrix( FLA_RIGHT, FLA_NO_CONJUGATE, s, UL );
162 FLA_Gemm( FLA_NO_TRANSPOSE, FLA_CONJ_TRANSPOSE,
163 FLA_ONE, UL, VL, FLA_ZERO, A );
164 }
165
166 /*
167 *FLA_DOUBLE_PTR( alpha ) = 1.0e-169;
168 FLA_Scal( alpha, A );
169 FLA_Obj_show( "A", A, "%10.2e", "" );
170 */
171
172 FLA_Set( FLA_ZERO, s );
173 FLA_Set( FLA_ZERO, U );
174 FLA_Set( FLA_ZERO, V );
175
176 time_Svd_uv( 0, FLA_ALG_REFERENCE, n_repeats, m, n, n_iter_max, k_accum, b_alg,
177 A, s, U, V, &dtime, &diff1, &diff2, &gflops, &k_perf );
178
179 fprintf( stdout, "data_REFq( %d, 1:5 ) = [ %d %6.3lf %8.2e %6.2le %6.2le %5d ]; \n", i, p, gflops, dtime, diff1, diff2, k_perf );
180 fflush( stdout );
181
182 time_Svd_uv( -1, FLA_ALG_REFERENCE, n_repeats, m, n, n_iter_max, k_accum, b_alg,
183 A, s, U, V, &dtime, &diff1, &diff2, &gflops, &k_perf );
184
185 fprintf( stdout, "data_REFd( %d, 1:5 ) = [ %d %6.3lf %8.2e %6.2le %6.2le %5d ]; \n", i, p, gflops, dtime, diff1, diff2, k_perf );
186 fflush( stdout );
187
188 for ( variant = 1; variant <= n_variants; variant++ ){
189
190 fprintf( stdout, "data_var%d( %d, 1:5 ) = [ %d ", variant, i, p );
191 fflush( stdout );
192
193 time_Svd_uv( variant, FLA_ALG_UNBLOCKED, n_repeats, m, n, n_iter_max, k_accum, b_alg,
194 A, s, U, V, &dtime, &diff1, &diff2, &gflops, &k_perf );
195
196 fprintf( stdout, "%6.3lf %8.2e %6.2le %6.2le %5d ", gflops, dtime, diff1, diff2, k_perf );
197 fflush( stdout );
198
199 fprintf( stdout, "];\n" );
200 fflush( stdout );
201 }
202
203 fprintf( stdout, "\n" );
204
205 FLA_Obj_free( &A );
206 FLA_Obj_free( &s );
207 FLA_Obj_free( &U );
208 FLA_Obj_free( &V );
209 FLA_Obj_free( &alpha );
210 }
211
212 /*
213 fprintf( stdout, "figure;\n" );
214
215 fprintf( stdout, "plot( data_REF( :,1 ), data_REF( :, 2 ), '-' ); \n" );
216
217 fprintf( stdout, "hold on;\n" );
218
219 for ( i = 1; i <= n_variants; i++ ) {
220 fprintf( stdout, "plot( data_var%d( :,1 ), data_var%d( :, 2 ), '%c:%c' ); \n",
221 i, i, colors[ i-1 ], ticks[ i-1 ] );
222 fprintf( stdout, "plot( data_var%d( :,1 ), data_var%d( :, 4 ), '%c-.%c' ); \n",
223 i, i, colors[ i-1 ], ticks[ i-1 ] );
224 }
225
226 fprintf( stdout, "legend( ... \n" );
227 fprintf( stdout, "'Reference', ... \n" );
228
229 for ( i = 1; i < n_variants; i++ )
230 fprintf( stdout, "'unb\\_var%d', 'blk\\_var%d', ... \n", i, i );
231 fprintf( stdout, "'unb\\_var%d', 'blk\\_var%d' ); \n", i, i );
232
233 fprintf( stdout, "xlabel( 'problem size p' );\n" );
234 fprintf( stdout, "ylabel( 'GFLOPS/sec.' );\n" );
235 fprintf( stdout, "axis( [ 0 %d 0 %.2f ] ); \n", p_last, max_gflops );
236 fprintf( stdout, "title( 'FLAME Svd_uv performance (%s, %s)' );\n",
237 m_dim_desc, n_dim_desc );
238 fprintf( stdout, "print -depsc tridiag_%s_%s.eps\n", m_dim_tag, n_dim_tag );
239 fprintf( stdout, "hold off;\n");
240 fflush( stdout );
241 */
242
243 FLA_Finalize( );
244
245 return 0;
246 }
247
248