1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
/* Integer tags for two algorithm choices (reference vs. front-end, judging
   by the names). Neither macro is referenced in the visible part of this
   file -- NOTE(review): confirm whether they are still needed. */
#define FLA_ALG_REFERENCE 0
#define FLA_ALG_FRONT     1
15 
/*
   Timing routine called by main() once per problem size and variant. Only
   the declaration appears here; the definition is not in the visible part
   of this file.

   main() passes the repeat count, the operand dimensions (mB, mC, mD, n,
   n_rhs), the algorithmic blocksize, the hierarchical matrices R_BC, R_BD,
   C, D, T, a workspace object for W, the right-hand-side objects bR_BC,
   bR_BD, bC, bD, and three output pointers. After the call main() prints
   *gflops and *diff; dtime is passed as well but main() never reads it.

   NOTE(review): the exact meaning of *diff (presumably some difference or
   residual measure) cannot be determined from this file -- confirm against
   the definition.
*/
void time_Apply_QUD_UT_inc(
                 int n_repeats, int mB, int mC, int mD, int n, int n_rhs, dim_t b_alg,
                 FLA_Obj R_BC, FLA_Obj R_BD, FLA_Obj C, FLA_Obj D, FLA_Obj T, FLA_Obj W,
                 FLA_Obj bR_BC, FLA_Obj bR_BD, FLA_Obj bC, FLA_Obj bD,
                 double *dtime, double *diff, double *gflops );
21 
22 
main(int argc,char * argv[])23 int main(int argc, char *argv[])
24 {
25   int
26     datatype,
27     n_input,
28     n_rhs_input,
29     mB_input, mC_input, mD_input,
30     n_rhs,
31     mB, mC, mD, n,
32     p_first, p_last, p_inc,
33     p,
34     variant,
35     n_repeats,
36     i,
37     n_variants = 1;
38 
39   double max_gflops=6.0;
40 
41   double
42     dtime,
43     gflops,
44     diff;
45 
46   dim_t b_alg, b_flash, n_threads;
47 
48   FLA_Obj
49     R_BD_flat, R_BC_flat, B_flat, C_flat, D_flat,
50     R_BD, R_BC, B, C, D, T, W, W2,
51     bR_BD, bR_BC, bB, bC, bD;
52 
53 
54   FLA_Init();
55 
56 
57   fprintf( stdout, "%c number of repeats:", '%' );
58   scanf( "%d", &n_repeats );
59   fprintf( stdout, "%c %d\n", '%', n_repeats );
60 
61   fprintf( stdout, "%c enter algorithmic blocksize:", '%' );
62   scanf( "%u", &b_alg );
63   fprintf( stdout, "%c %u\n", '%', b_alg );
64 
65   fprintf( stdout, "%c enter FLASH blocksize: ", '%' );
66   scanf( "%u", &b_flash );
67   fprintf( stdout, "%c %u\n", '%', b_flash );
68 
69   fprintf( stdout, "%c enter problem size first, last, inc:", '%' );
70   scanf( "%d%d%d", &p_first, &p_last, &p_inc );
71   fprintf( stdout, "%c %d %d %d\n", '%', p_first, p_last, p_inc );
72 
73   fprintf( stdout, "%c enter n (-1 means bind to problem size): ", '%' );
74   scanf( "%d", &n_input );
75   fprintf( stdout, "%c %d\n", '%', n_input );
76 
77   fprintf( stdout, "%c enter mB mC mD (-1 means bind to problem size): ", '%' );
78   scanf( "%d %d %d", &mB_input, &mC_input, &mD_input );
79   fprintf( stdout, "%c %d %d %d\n", '%', mB_input, mC_input, mD_input );
80 
81   fprintf( stdout, "%c enter n_rhs (-1 means bind to problem size): ", '%' );
82   scanf( "%d", &n_rhs_input );
83   fprintf( stdout, "%c %d\n", '%', n_rhs_input );
84 
85   fprintf( stdout, "%c enter the number of SuperMatrix threads: ", '%' );
86   scanf( "%u", &n_threads );
87   fprintf( stdout, "%c %u\n", '%', n_threads );
88 
89 
90   fprintf( stdout, "\nclear all;\n\n" );
91 
92 
93 
94   //datatype = FLA_FLOAT;
95   //datatype = FLA_DOUBLE;
96   //datatype = FLA_COMPLEX;
97   datatype = FLA_DOUBLE_COMPLEX;
98 
99   FLASH_Queue_set_num_threads( n_threads );
100   //FLASH_Queue_set_verbose_output( TRUE );
101   //FLASH_Queue_disable();
102 
103   for ( p = p_first, i = 1; p <= p_last; p += p_inc, i += 1 )
104   {
105     mB    = mB_input;
106     mC    = mC_input;
107     mD    = mD_input;
108     n     = n_input;
109     n_rhs = n_rhs_input;
110 
111     if( mB    < 0 ) mB    = p / f2c_abs(mB_input);
112     if( mC    < 0 ) mC    = p / f2c_abs(mC_input);
113     if( mD    < 0 ) mD    = p / f2c_abs(mD_input);
114     if( n     < 0 ) n     = p / f2c_abs(n_input);
115     if( n_rhs < 0 ) n_rhs = p / f2c_abs(n_rhs_input);
116 
117     for ( variant = 0; variant < n_variants; variant++ ){
118 
119       FLA_Obj_create( datatype, mB, n, 0, 0, &B_flat );
120       FLA_Obj_create( datatype, mC, n, 0, 0, &C_flat );
121       FLA_Obj_create( datatype, mD, n, 0, 0, &D_flat );
122       FLA_Obj_create( datatype, n,  n, 0, 0, &R_BC_flat );
123       FLA_Obj_create( datatype, n,  n, 0, 0, &R_BD_flat );
124 
125       FLA_Random_matrix( B_flat );
126       FLA_Random_matrix( C_flat );
127       FLA_Random_matrix( D_flat );
128 
129       FLA_Set( FLA_ZERO, R_BD_flat );
130       FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE,
131                          FLA_ONE, B_flat, FLA_ONE, R_BD_flat );
132       FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE,
133                          FLA_ONE, D_flat, FLA_ONE, R_BD_flat );
134       FLA_Chol( FLA_UPPER_TRIANGULAR, R_BD_flat );
135 
136       FLA_Set( FLA_ZERO, R_BC_flat );
137       FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE,
138                          FLA_ONE, B_flat, FLA_ONE, R_BC_flat );
139       FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE,
140                          FLA_ONE, C_flat, FLA_ONE, R_BC_flat );
141       FLA_Chol( FLA_UPPER_TRIANGULAR, R_BC_flat );
142 
143       FLASH_Obj_create_hier_copy_of_flat( B_flat, 1, &b_flash, &B );
144       FLASH_Obj_create_hier_copy_of_flat( R_BC_flat, 1, &b_flash, &R_BC );
145       FLASH_UDdate_UT_inc_create_hier_matrices( R_BD_flat, C_flat, D_flat,
146                                                 1, &b_flash, b_alg, &R_BD, &C, &D, &T, &W );
147 
148       FLASH_Obj_create( datatype, mB, n_rhs, 1, &b_flash, &bB );
149       FLASH_Obj_create( datatype, mC, n_rhs, 1, &b_flash, &bC );
150       FLASH_Obj_create( datatype, mD, n_rhs, 1, &b_flash, &bD );
151       FLASH_Obj_create( datatype, n,  n_rhs, 1, &b_flash, &bR_BC );
152       FLASH_Obj_create( datatype, n,  n_rhs, 1, &b_flash, &bR_BD );
153 
154       FLASH_Random_matrix( bB );
155       FLASH_Random_matrix( bC );
156       FLASH_Random_matrix( bD );
157 
158       FLASH_Gemm( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BD );
159       FLASH_Gemm( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, D, bD, FLA_ONE,  bR_BD );
160       FLASH_Trsm( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BD, bR_BD );
161 
162       FLASH_Gemm( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BC );
163       FLASH_Gemm( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, C, bC, FLA_ONE,  bR_BC );
164       FLASH_Trsm( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BC, bR_BC );
165 
166       FLASH_UDdate_UT_inc( R_BD, C, D, T, W );
167 
168       FLASH_Apply_QUD_UT_inc_create_workspace( T, bR_BD, &W2 );
169 
170       fprintf( stdout, "data_apqudutinc( %d, 1:5 ) = [ %d  ", i, p );
171       fflush( stdout );
172 
173       time_Apply_QUD_UT_inc( n_repeats, mB, mC, mD, n, n_rhs, b_alg,
174                              R_BC, R_BD, C, D, T, W2, bR_BC, bR_BD, bC, bD, &dtime, &diff, &gflops );
175 
176       fprintf( stdout, "%6.3lf %6.2le ", gflops, diff );
177       fflush( stdout );
178 
179       fprintf( stdout, " ]; \n" );
180       fflush( stdout );
181 
182       FLA_Obj_free( &B_flat );
183       FLA_Obj_free( &C_flat );
184       FLA_Obj_free( &D_flat );
185       FLA_Obj_free( &R_BC_flat );
186       FLA_Obj_free( &R_BD_flat );
187 
188       FLASH_Obj_free( &B );
189       FLASH_Obj_free( &C );
190       FLASH_Obj_free( &D );
191       FLASH_Obj_free( &T );
192       FLASH_Obj_free( &W );
193       FLASH_Obj_free( &W2 );
194       FLASH_Obj_free( &R_BC );
195       FLASH_Obj_free( &R_BD );
196 
197       FLASH_Obj_free( &bB );
198       FLASH_Obj_free( &bC );
199       FLASH_Obj_free( &bD );
200       FLASH_Obj_free( &bR_BC );
201       FLASH_Obj_free( &bR_BD );
202     }
203 
204     fprintf( stdout, "\n" );
205   }
206 
207 /*
208   fprintf( stdout, "figure;\n" );
209 
210   fprintf( stdout, "hold on;\n" );
211 
212   for ( i = 0; i < n_variants; i++ ) {
213     fprintf( stdout, "plot( data_apqudutinc( :,1 ), data_apqudutinc( :, 2 ), '%c:%c' ); \n",
214             colors[ i ], ticks[ i ] );
215     fprintf( stdout, "plot( data_apqudutinc( :,1 ), data_apqudutinc( :, 4 ), '%c-.%c' ); \n",
216             colors[ i ], ticks[ i ] );
217   }
218 
219   fprintf( stdout, "legend( ... \n" );
220 
221   for ( i = 0; i < n_variants; i++ )
222     fprintf( stdout, "'ref\\_qr\\_ut', 'fla\\_qr\\_ut', ... \n" );
223 
224   fprintf( stdout, "'Location', 'SouthEast' ); \n" );
225 
226   fprintf( stdout, "xlabel( 'problem size p' );\n" );
227   fprintf( stdout, "ylabel( 'GFLOPS/sec.' );\n" );
228   fprintf( stdout, "axis( [ 0 %d 0 %.2f ] ); \n", p_last, max_gflops );
229   fprintf( stdout, "title( 'FLAME Apply_QUD_UT_inc front-end performance (%s, %s)' );\n",
230            m_dim_desc, n_dim_desc );
231   fprintf( stdout, "print -depsc apqudut_front_%s_%s.eps\n", m_dim_tag, n_dim_tag );
232   fprintf( stdout, "hold off;\n");
233   fflush( stdout );
234 */
235 
236   FLA_Finalize( );
237 
238   return 0;
239 }
240 
241