1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
// Test configuration handed to libfla_test_op_driver(): a single parameter
// combination, four matrix arguments (the experiment reads storage characters
// sc_str[0] through sc_str[3]), and only variant 1 is exercised.
#define NUM_PARAM_COMBOS 1
#define NUM_MATRIX_ARGS  4
#define FIRST_VARIANT    1
#define LAST_VARIANT     1

// Static variables.
// Labels used when reporting results for this operation.
static char* op_str                   = "Up/downdate via UD UT transform";
static char* fla_front_str            = "FLA_UDdate_UT";
static char* fla_unb_var_str          = "unb_var";
static char* fla_opt_var_str          = "opt_var";
static char* fla_blk_var_str          = "blk_var";
static char* pc_str[NUM_PARAM_COMBOS] = { "" };
// Residual thresholds passed to the driver; presumably compared against the
// value written to *residual by the experiment (TODO confirm in the driver).
static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
                                          1e-11, 1e-12,   // warn, pass for d
                                          1e-02, 1e-03,   // warn, pass for c
                                          1e-11, 1e-12 }; // warn, pass for z

// Control trees and their shared blocksize object for testing individual
// variants. Created by libfla_test_uddateut_cntl_create(), released by
// libfla_test_uddateut_cntl_free(), and consumed by libfla_test_uddateut_impl().
static fla_apqudut_t*   apqudut_cntl_blk;
static fla_uddateut_t*  uddateut_cntl_opt;
static fla_uddateut_t*  uddateut_cntl_unb;
static fla_uddateut_t*  uddateut_cntl_blk;
static fla_blocksize_t* uddateut_cntl_bsize;
36 
// Local prototypes.

// Run one timed experiment for a given variant, datatype and problem size,
// writing the best observed performance to *perf and an accuracy measure to
// *residual.
void libfla_test_uddateut_experiment( test_params_t params,
                                      unsigned int  var,
                                      char*         sc_str,
                                      FLA_Datatype  datatype,
                                      unsigned int  p,
                                      unsigned int  pci,
                                      unsigned int  n_repeats,
                                      signed int    impl,
                                      double*       perf,
                                      double*       residual );
// Dispatch to the front end or to the requested variant's control tree.
void libfla_test_uddateut_impl( int     impl,
                                FLA_Obj R,
                                FLA_Obj C,
                                FLA_Obj D,
                                FLA_Obj T );
// Build the file-scope control trees for variant `var` with blocksize
// b_alg_flat.
void libfla_test_uddateut_cntl_create( unsigned int var,
                                       dim_t        b_alg_flat );
// Release everything allocated by libfla_test_uddateut_cntl_create().
void libfla_test_uddateut_cntl_free( void );
56 
57 
// Hand one implementation class (front end, or unb/opt/blk variants) to the
// generic test driver. var_str is NULL when testing the front end itself.
static void libfla_test_uddateut_drive( char*         var_str,
                                        signed int    impl,
                                        test_params_t params )
{
	libfla_test_op_driver( fla_front_str, var_str,
	                       FIRST_VARIANT, LAST_VARIANT,
	                       NUM_PARAM_COMBOS, pc_str,
	                       NUM_MATRIX_ARGS,
	                       impl,
	                       params, thresh, libfla_test_uddateut_experiment );
}

// Entry point for the up/downdate test. Prints a header, then runs the
// driver for each implementation class enabled in `op`, in the fixed order
// front end, unblocked, optimized, blocked. output_stream is not used here.
void libfla_test_uddateut( FILE* output_stream, test_params_t params, test_op_t op )
{
	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	if ( op.fla_front == ENABLE )
		libfla_test_uddateut_drive( NULL, FLA_TEST_FLAT_FRONT_END, params );

	if ( op.fla_unb_vars == ENABLE )
		libfla_test_uddateut_drive( fla_unb_var_str, FLA_TEST_FLAT_UNB_VAR, params );

	if ( op.fla_opt_vars == ENABLE )
		libfla_test_uddateut_drive( fla_opt_var_str, FLA_TEST_FLAT_OPT_VAR, params );

	if ( op.fla_blk_vars == ENABLE )
		libfla_test_uddateut_drive( fla_blk_var_str, FLA_TEST_FLAT_BLK_VAR, params );
}
104 
105 
106 
// Real-arithmetic flop count reported for one up/downdate with an mC-row
// update matrix, an mD-row downdate matrix and n columns. The arithmetic is
// done in double precision: the original unsigned product (mC + mD) * n * n
// wraps on platforms with 32-bit unsigned int once the problem size reaches
// about 2048 (e.g. 1024 * 2048 * 2048 == 2^32 wraps to 0).
static double libfla_test_uddateut_flops( unsigned int mC,
                                          unsigned int mD,
                                          unsigned int n )
{
	double m_cd  = ( double ) mC + ( double ) mD;
	double n_dbl = ( double ) n;

	return 2.0 * ( m_cd * n_dbl * n_dbl + m_cd * n_dbl * 6.0 );
}

// Run one experiment: build R = chol(B'B + D'D), then update R with C and
// downdate it with D. The result should match E = chol(B'B + C'C).
// The computation is repeated n_repeats times from saved copies of R, C and D.
// The fastest run is reported in *perf. The maximum element-wise difference
// between R'R and E'E is reported in *residual.
//
// params     - test parameters; only b_alg_flat is read here.
// var        - variant number, used when impl selects a variant.
// sc_str     - storage-scheme characters for B/R/E/RR/EE (0), C (1), D (2)
//              and T (3).
// datatype   - FLAME datatype of every object.
// p_cur      - problem size. Each dimension is p_cur divided by |x_input|,
//              where x_input is that dimension's *_input local (L123-126).
// n_repeats  - number of timed repetitions.
// impl       - which implementation to test (front end or a flat variant).
void libfla_test_uddateut_experiment( test_params_t params,
                                      unsigned int  var,
                                      char*         sc_str,
                                      FLA_Datatype  datatype,
                                      unsigned int  p_cur,
                                      unsigned int  pci,
                                      unsigned int  n_repeats,
                                      signed int    impl,
                                      double*       perf,
                                      double*       residual )
{
	dim_t        b_alg_flat = params.b_alg_flat;
	double       time_min   = 1e9;
	double       time;
	unsigned int i;
	unsigned int mB, mC, mD, n;
	signed int   mB_input   = -1;
	signed int   mC_input   = -4;
	signed int   mD_input   = -4;
	signed int   n_input    = -1;
	int          is_variant;
	FLA_Obj      B, C, D, T, R, RR, E, EE;
	FLA_Obj      R_save, C_save, D_save;

	// Only the individual variants need a test-local control tree.
	is_variant = ( impl == FLA_TEST_FLAT_UNB_VAR ||
	               impl == FLA_TEST_FLAT_OPT_VAR ||
	               impl == FLA_TEST_FLAT_BLK_VAR );

	// Determine the dimensions. A negative *_input means "p_cur divided by
	// its magnitude"; a non-negative value means "use p_cur".
	if ( mB_input < 0 ) mB = p_cur / abs(mB_input);
	else                mB = p_cur;
	if ( mC_input < 0 ) mC = p_cur / abs(mC_input);
	else                mC = p_cur;
	if ( mD_input < 0 ) mD = p_cur / abs(mD_input);
	else                mD = p_cur;
	if ( n_input  < 0 ) n  = p_cur / abs(n_input);
	else                n  = p_cur;

	// Create the matrices for the current operation.
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], mB, n, &B );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], mC, n, &C );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[2], mD, n, &D );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &R );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &E );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &RR );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &EE );

	// The front end and blocked variant 1 use a b_alg_flat x n T; the other
	// variants are given an n x n T.
	if ( impl == FLA_TEST_FLAT_FRONT_END ||
	     ( impl == FLA_TEST_FLAT_BLK_VAR && var == 1 ) )
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], b_alg_flat, n, &T );
	else
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], n, n, &T );

	// Initialize the test matrices.
	FLA_Random_matrix( B );
	FLA_Random_matrix( C );
	FLA_Random_matrix( D );

	// Initialize the test factorization: R = chol( B'B + D'D ) (upper).
	FLA_Set( FLA_ZERO, R );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, R );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, D, FLA_ONE, R );
	FLA_Chol( FLA_UPPER_TRIANGULAR, R );

	// Initialize the reference factorization: E = chol( B'B + C'C ) (upper).
	FLA_Set( FLA_ZERO, E );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, E );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, C, FLA_ONE, E );
	FLA_Chol( FLA_UPPER_TRIANGULAR, E );

	// Save the original test matrices so every repeat starts from them.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, R, &R_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, C, &C_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, D, &D_save );

	// Create a control tree for the individual variants.
	if ( is_variant )
		libfla_test_uddateut_cntl_create( var, b_alg_flat );

	// Repeat the experiment n_repeats times and record the fastest run.
	for ( i = 0; i < n_repeats; ++i )
	{
		FLA_Copy_external( R_save, R );
		FLA_Copy_external( C_save, C );
		FLA_Copy_external( D_save, D );

		time = FLA_Clock();

		libfla_test_uddateut_impl( impl, R, C, D, T );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Compute R'R and E'E.
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, R, R, FLA_ZERO, RR );
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, E, E, FLA_ZERO, EE );

	// Free the control trees if we're testing the variants.
	if ( is_variant )
		libfla_test_uddateut_cntl_free();

	// Compute the performance of the best experiment repeat. Complex
	// arithmetic costs four real flops per real-domain flop.
	*perf = libfla_test_uddateut_flops( mC, mD, n ) / time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( R ) ) *perf *= 4.0;

	// Compute the maximum element-wise difference between R'R and E'E and use
	// this instead of the residual.
	*residual = FLA_Max_elemwise_diff( RR, EE );

	// Free the supporting flat objects.
	FLA_Obj_free( &R_save );
	FLA_Obj_free( &C_save );
	FLA_Obj_free( &D_save );

	// Free the flat test matrices.
	FLA_Obj_free( &B );
	FLA_Obj_free( &C );
	FLA_Obj_free( &D );
	FLA_Obj_free( &T );
	FLA_Obj_free( &R );
	FLA_Obj_free( &RR );
	FLA_Obj_free( &E );
	FLA_Obj_free( &EE );
}
232 
233 
234 
// Control trees defined elsewhere in libflame, used as the subproblem trees
// of the blocked apqudut control object built in
// libfla_test_uddateut_cntl_create(). fla_trmm_cntl_blas is declared here
// but not referenced below.
extern fla_axpyt_t* fla_axpyt_cntl_blas;
extern fla_copyt_t* fla_copyt_cntl_blas;
extern fla_gemm_t*  fla_gemm_cntl_blas;
extern fla_trmm_t*  fla_trmm_cntl_blas;
extern fla_trsm_t*  fla_trsm_cntl_blas;
240 
// Build the file-scope control trees used to test variant `var`:
//   - one blocksize object, with all four fields set to b_alg_flat;
//   - flat unblocked and optimized UDdate trees with no subtrees;
//   - a flat blocked apqudut tree (blocked variant 1) over BLAS subtrees;
//   - a flat blocked UDdate tree that recurses into the optimized tree and
//     applies transforms via the apqudut tree.
// Everything created here is released by libfla_test_uddateut_cntl_free().
void libfla_test_uddateut_cntl_create( unsigned int var,
                                       dim_t        b_alg_flat )
{
	int unb_id = FLA_UNB_VAR_OFFSET + var;
	int opt_id = FLA_OPT_VAR_OFFSET + var;
	int blk_id = FLA_BLK_VAR_OFFSET + var;

	uddateut_cntl_bsize = FLA_Blocksize_create( b_alg_flat, b_alg_flat,
	                                            b_alg_flat, b_alg_flat );

	// Leaf trees: no blocksize and no subproblem trees.
	uddateut_cntl_unb = FLA_Cntl_uddateut_obj_create( FLA_FLAT, unb_id,
	                                                  NULL, NULL, NULL );
	uddateut_cntl_opt = FLA_Cntl_uddateut_obj_create( FLA_FLAT, opt_id,
	                                                  NULL, NULL, NULL );

	// Tree for applying the block transforms inside the blocked variant.
	apqudut_cntl_blk = FLA_Cntl_apqudut_obj_create( FLA_FLAT,
	                                                FLA_BLOCKED_VARIANT1,
	                                                uddateut_cntl_bsize,
	                                                NULL,
	                                                fla_gemm_cntl_blas,
	                                                fla_gemm_cntl_blas,
	                                                fla_gemm_cntl_blas,
	                                                fla_gemm_cntl_blas,
	                                                fla_trsm_cntl_blas,
	                                                fla_copyt_cntl_blas,
	                                                fla_axpyt_cntl_blas );

	// Blocked tree: depends on the optimized leaf and the apqudut tree above.
	uddateut_cntl_blk = FLA_Cntl_uddateut_obj_create( FLA_FLAT, blk_id,
	                                                  uddateut_cntl_bsize,
	                                                  uddateut_cntl_opt,
	                                                  apqudut_cntl_blk );
}
280 
281 
282 
libfla_test_uddateut_cntl_free(void)283 void libfla_test_uddateut_cntl_free( void )
284 {
285 	FLA_Blocksize_free( uddateut_cntl_bsize );
286 
287 	FLA_Cntl_obj_free( apqudut_cntl_blk );
288 	FLA_Cntl_obj_free( uddateut_cntl_unb );
289 	FLA_Cntl_obj_free( uddateut_cntl_opt );
290 	FLA_Cntl_obj_free( uddateut_cntl_blk );
291 }
292 
293 
294 
// Perform one up/downdate of R with C and D, storing block transforms in T.
// The front end is called directly. Each flat variant goes through
// FLA_UDdate_UT_internal with its file-scope control tree. Any other impl
// value is reported as an error and nothing is computed.
void libfla_test_uddateut_impl( int     impl,
                                FLA_Obj R,
                                FLA_Obj C,
                                FLA_Obj D,
                                FLA_Obj T )
{
	fla_uddateut_t* cntl;

	if ( impl == FLA_TEST_FLAT_FRONT_END )
	{
		FLA_UDdate_UT( R, C, D, T );
		return;
	}

	if      ( impl == FLA_TEST_FLAT_UNB_VAR ) cntl = uddateut_cntl_unb;
	else if ( impl == FLA_TEST_FLAT_OPT_VAR ) cntl = uddateut_cntl_opt;
	else if ( impl == FLA_TEST_FLAT_BLK_VAR ) cntl = uddateut_cntl_blk;
	else
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
		return;
	}

	FLA_UDdate_UT_internal( R, C, D, T, cntl );
}
323 
324