1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
14 #define NUM_PARAM_COMBOS 4
15 #define NUM_MATRIX_ARGS  3
16 #define FIRST_VARIANT    1
17 #define LAST_VARIANT     10
18 
19 // Static variables.
20 static char* op_str                   = "Symmetric matrix-matrix multiply";
21 static char* flash_front_str          = "FLASH_Symm";
22 static char* fla_front_str            = "FLA_Symm";
23 static char* fla_unb_var_str          = "unb_var";
24 //static char* fla_opt_var_str          = "opt_var";
25 static char* fla_blk_var_str          = "blk_var";
26 static char* fla_unb_ext_str          = "external";
27 static char* pc_str[NUM_PARAM_COMBOS] = { "ll", "lu",
28                                           "rl", "ru" };
29 static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
30                                           1e-11, 1e-12,   // warn, pass for d
31                                           1e-02, 1e-03,   // warn, pass for c
32                                           1e-11, 1e-12 }; // warn, pass for z
33 
34 static fla_symm_t*      symm_cntl_unb;
35 static fla_symm_t*      symm_cntl_blk;
36 static fla_blocksize_t* symm_cntl_bsize;
37 
38 // Local prototypes.
39 void libfla_test_symm_experiment( test_params_t params,
40                                   unsigned int  var,
41                                   char*         sc_str,
42                                   FLA_Datatype  datatype,
43                                   unsigned int  p_cur,
44                                   unsigned int  pci,
45                                   unsigned int  n_repeats,
46                                   signed int    impl,
47                                   double*       perf,
48                                   double*       residual );
49 void libfla_test_symm_impl( int         impl,
50                             FLA_Side    side,
51                             FLA_Uplo    uplo,
52                             FLA_Obj     alpha,
53                             FLA_Obj     A,
54                             FLA_Obj     B,
55                             FLA_Obj     beta,
56                             FLA_Obj     C );
57 void libfla_test_symm_cntl_create( unsigned int var,
58                                    dim_t        b_alg_flat );
59 void libfla_test_symm_cntl_free( void );
60 
61 
/*
 * Entry point for the SYMM tests: prints a banner, then runs the test driver
 * once for each implementation enabled in op, in this fixed order:
 * FLASH front-end, FLA front-end, unblocked variants, blocked variants,
 * external implementation. output_stream is not used here.
 */
void libfla_test_symm( FILE* output_stream, test_params_t params, test_op_t op )
{
	// One column per implementation: enabled flag, front-end label,
	// variant label (NULL for the front ends) and FLA_TEST_* code.
	const int   enabled[]    = { op.flash_front  == ENABLE,
	                             op.fla_front    == ENABLE,
	                             op.fla_unb_vars == ENABLE,
	                             op.fla_blk_vars == ENABLE,
	                             op.fla_unb_ext  == ENABLE };
	char* const front_strs[] = { flash_front_str,
	                             fla_front_str,
	                             fla_front_str,
	                             fla_front_str,
	                             fla_front_str };
	char* const var_strs[]   = { NULL,
	                             NULL,
	                             fla_unb_var_str,
	                             fla_blk_var_str,
	                             fla_unb_ext_str };
	const int   impl_codes[] = { FLA_TEST_HIER_FRONT_END,
	                             FLA_TEST_FLAT_FRONT_END,
	                             FLA_TEST_FLAT_UNB_VAR,
	                             FLA_TEST_FLAT_BLK_VAR,
	                             FLA_TEST_FLAT_UNB_EXT };
	const unsigned int n_impls = sizeof( impl_codes ) / sizeof( impl_codes[0] );
	unsigned int       k;

	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	for ( k = 0; k < n_impls; ++k )
	{
		if ( !enabled[k] )
			continue;

		libfla_test_op_driver( front_strs[k], var_strs[k],
		                       FIRST_VARIANT, LAST_VARIANT,
		                       NUM_PARAM_COMBOS, pc_str,
		                       NUM_MATRIX_ARGS,
		                       impl_codes[k],
		                       params, thresh, libfla_test_symm_experiment );
	}
}
117 
118 
119 
/*
 * Run one SYMM experiment for a single problem size and parameter combo.
 *
 * Creates a symmetric A (m x m for side = left, n x n for side = right) and
 * m x n matrices B and C, then times n_repeats calls of
 *   C := beta * C + alpha * A * B      (side = left)
 *   C := beta * C + alpha * B * A      (side = right)
 * with alpha = 2 and beta = -1, using the implementation selected by impl.
 *
 * The result is checked by applying both sides to a random vector x:
 * y = C_result * x is compared against z = beta * C_orig * x + alpha * A * (B * x)
 * (or alpha * B * (A * x) for side = right).
 *
 * params    - test parameters; supplies b_flash and b_alg_flat.
 * var       - variant number (used only by the variant implementations).
 * sc_str    - one character per matrix (A, B, C) passed to libfla_test_obj_create.
 * datatype  - libflame datatype of the test objects.
 * p_cur     - current problem size (m = n = p_cur with the settings below).
 * pci       - index into pc_str selecting side and uplo.
 * n_repeats - number of timed repetitions; the fastest one is reported.
 * impl      - FLA_TEST_* code selecting the implementation under test.
 * perf      - out: best-case rate, scaled by FLOPS_PER_UNIT_PERF.
 * residual  - out: FLA_Max_elemwise_diff( y, z ).
 */
void libfla_test_symm_experiment( test_params_t params,
                                  unsigned int  var,
                                  char*         sc_str,
                                  FLA_Datatype  datatype,
                                  unsigned int  p_cur,
                                  unsigned int  pci,
                                  unsigned int  n_repeats,
                                  signed int    impl,
                                  double*       perf,
                                  double*       residual )
{
	dim_t        b_flash    = params.b_flash;
	dim_t        b_alg_flat = params.b_alg_flat;
	double       time_min   = 1e9;
	double       time;
	double       flops;
	unsigned int i;
	unsigned int m;
	signed int   m_input    = -1;
	unsigned int n;
	signed int   n_input    = -1;
	int          use_cntl;
	FLA_Side     side;
	FLA_Uplo     uplo;
	FLA_Obj      A, B, C, x, y, z, w;
	FLA_Obj      alpha, beta;
	FLA_Obj      C_save;
	FLA_Obj      A_test, B_test, C_test;

	// Determine the dimensions.
	if ( m_input < 0 ) m = p_cur / abs(m_input);
	else               m = p_cur;
	if ( n_input < 0 ) n = p_cur / abs(n_input);
	else               n = p_cur;

	// Translate parameter characters to libflame constants.
	FLA_Param_map_char_to_flame_side( &pc_str[pci][0], &side );
	FLA_Param_map_char_to_flame_uplo( &pc_str[pci][1], &uplo );

	// Create the matrices for the current operation. A is square, with its
	// order matching the dimension of B that it multiplies.
	if ( side == FLA_LEFT )
	{
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, m, &A );

		// Vectors for the check; w holds B * x (length m).
		FLA_Obj_create( datatype, n, 1, 0, 0, &x );
		FLA_Obj_create( datatype, m, 1, 0, 0, &y );
		FLA_Obj_create( datatype, m, 1, 0, 0, &z );
		FLA_Obj_create( datatype, m, 1, 0, 0, &w );
	}
	else
	{
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n, n, &A );

		// Vectors for the check; w holds A * x (length n).
		FLA_Obj_create( datatype, n, 1, 0, 0, &x );
		FLA_Obj_create( datatype, m, 1, 0, 0, &y );
		FLA_Obj_create( datatype, m, 1, 0, 0, &z );
		FLA_Obj_create( datatype, n, 1, 0, 0, &w );
	}
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], m, n, &B );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[2], m, n, &C );

	// Initialize the test matrices.
	FLA_Random_symm_matrix( uplo, A );
	FLA_Random_matrix( B );
	FLA_Random_matrix( C );

	// Initialize the test vectors.
	FLA_Random_matrix( x );
	FLA_Set( FLA_ZERO, y );
	FLA_Set( FLA_ZERO, z );
	FLA_Set( FLA_ZERO, w );

	// Set constants.
	alpha = FLA_TWO;
	beta  = FLA_MINUS_ONE;

	// Save the original C so every repeat starts from the same input.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, C, &C_save );

	// Use hierarchical matrices if we're testing the FLASH front-end.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Obj_create_hier_copy_of_flat( A, 1, &b_flash, &A_test );
		FLASH_Obj_create_hier_copy_of_flat( B, 1, &b_flash, &B_test );
		FLASH_Obj_create_hier_copy_of_flat( C, 1, &b_flash, &C_test );
	}
	else
	{
		A_test = A;
		B_test = B;
		C_test = C;
	}

	// The variant implementations need control trees; build them once here
	// and release them after the timing loop.
	use_cntl = ( impl == FLA_TEST_FLAT_UNB_VAR ||
	             impl == FLA_TEST_FLAT_OPT_VAR ||
	             impl == FLA_TEST_FLAT_BLK_VAR ||
	             impl == FLA_TEST_FLAT_UNB_EXT ||
	             impl == FLA_TEST_FLAT_BLK_EXT );
	if ( use_cntl )
		libfla_test_symm_cntl_create( var, b_alg_flat );

	// Repeat the experiment n_repeats times and record the fastest run.
	for ( i = 0; i < n_repeats; ++i )
	{
		if ( impl == FLA_TEST_HIER_FRONT_END )
			FLASH_Obj_hierarchify( C_save, C_test );
		else
			FLA_Copy_external( C_save, C_test );

		time = FLA_Clock();

		libfla_test_symm_impl( impl, side, uplo, alpha, A_test, B_test, beta, C_test );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Copy the solution back to the flat matrix C and free the hierarchical
	// copies. For the flat implementations C_test and C are the same object.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Obj_flatten( C_test, C );

		FLASH_Obj_free( &A_test );
		FLASH_Obj_free( &B_test );
		FLASH_Obj_free( &C_test );
	}

	// Free the control trees if we're testing the variants.
	if ( use_cntl )
		libfla_test_symm_cntl_free();

	// Compute the performance of the best experiment repeat. SYMM performs
	// m*m*n (left) or m*n*n (right) multiplies and as many adds, so the flop
	// count is twice that. Evaluating in double also avoids the unsigned int
	// overflow the previous integer product hit for p_cur > ~1625.
	if ( side == FLA_LEFT )
		flops = 2.0 * m * m * n;
	else
		flops = 2.0 * m * n * n;
	*perf = flops / time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

	// Compute:
	//   y = C * x
	// and compare to
	//   z = ( beta * C_orig + alpha * A * B ) x      (side = left)
	//   z = ( beta * C_orig + alpha * B * A ) x      (side = right)
	FLA_Gemv_external( FLA_NO_TRANSPOSE, FLA_ONE, C, x, FLA_ZERO, y );

	if ( side == FLA_LEFT )
	{
		FLA_Gemv_external( FLA_NO_TRANSPOSE, FLA_ONE, B, x, FLA_ZERO, w );
		FLA_Symv_external( uplo,             alpha,   A, w, FLA_ZERO, z );
	}
	else
	{
		FLA_Symv_external( uplo,             FLA_ONE, A, x, FLA_ZERO, w );
		FLA_Gemv_external( FLA_NO_TRANSPOSE, alpha,   B, w, FLA_ZERO, z );
	}
	FLA_Gemv_external( FLA_NO_TRANSPOSE, beta, C_save, x, FLA_ONE, z );

	// Residual: largest elementwise difference between y and z.
	*residual = FLA_Max_elemwise_diff( y, z );

	// Free the supporting flat objects.
	FLA_Obj_free( &C_save );

	// Free the flat test matrices and vectors.
	FLA_Obj_free( &A );
	FLA_Obj_free( &B );
	FLA_Obj_free( &C );
	FLA_Obj_free( &x );
	FLA_Obj_free( &y );
	FLA_Obj_free( &z );
	FLA_Obj_free( &w );
}
311 
312 
313 
314 extern fla_scal_t* fla_scal_cntl_blas;
315 extern fla_gemm_t* fla_gemm_cntl_blas;
316 extern fla_symm_t* fla_symm_cntl_blas;
317 
libfla_test_symm_cntl_create(unsigned int var,dim_t b_alg_flat)318 void libfla_test_symm_cntl_create( unsigned int var,
319                                    dim_t        b_alg_flat )
320 {
321 	int var_unb = FLA_UNB_VAR_OFFSET + var;
322 	int var_blk = FLA_BLK_VAR_OFFSET + var;
323 
324 	symm_cntl_bsize = FLA_Blocksize_create( b_alg_flat,
325 	                                        b_alg_flat,
326 	                                        b_alg_flat,
327 	                                        b_alg_flat );
328 
329 	symm_cntl_unb   = FLA_Cntl_symm_obj_create( FLA_FLAT,
330 	                                            var_unb,
331 	                                            NULL,
332 	                                            NULL,
333 	                                            NULL,
334 	                                            NULL,
335 	                                            NULL );
336 
337 	symm_cntl_blk   = FLA_Cntl_symm_obj_create( FLA_FLAT,
338 	                                            var_blk,
339 	                                            symm_cntl_bsize,
340 	                                            fla_scal_cntl_blas,
341 	                                            fla_symm_cntl_blas,
342 	                                            fla_gemm_cntl_blas,
343 	                                            fla_gemm_cntl_blas );
344 }
345 
346 
347 
libfla_test_symm_cntl_free(void)348 void libfla_test_symm_cntl_free( void )
349 {
350 	FLA_Blocksize_free( symm_cntl_bsize );
351 
352 	FLA_Cntl_obj_free( symm_cntl_unb );
353 	FLA_Cntl_obj_free( symm_cntl_blk );
354 }
355 
356 
357 
/*
 * Perform one SYMM call, C := beta * C + alpha * A * B (left) or
 * alpha * B * A (right), through the implementation chosen by impl:
 *   FLA_TEST_HIER_FRONT_END - FLASH_Symm on hierarchical objects
 *   FLA_TEST_FLAT_FRONT_END - FLA_Symm
 *   FLA_TEST_FLAT_UNB_VAR   - FLA_Symm_internal with the unblocked tree
 *   FLA_TEST_FLAT_BLK_VAR   - FLA_Symm_internal with the blocked tree
 *   FLA_TEST_FLAT_UNB_EXT   - FLA_Symm_external
 * Any other value reports an error and leaves C untouched.
 */
void libfla_test_symm_impl( int       impl,
                            FLA_Side  side,
                            FLA_Uplo  uplo,
                            FLA_Obj   alpha,
                            FLA_Obj   A,
                            FLA_Obj   B,
                            FLA_Obj   beta,
                            FLA_Obj   C )
{
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Symm( side, uplo, alpha, A, B, beta, C );
	}
	else if ( impl == FLA_TEST_FLAT_FRONT_END )
	{
		FLA_Symm( side, uplo, alpha, A, B, beta, C );
	}
	else if ( impl == FLA_TEST_FLAT_UNB_VAR )
	{
		FLA_Symm_internal( side, uplo, alpha, A, B, beta, C, symm_cntl_unb );
	}
	else if ( impl == FLA_TEST_FLAT_BLK_VAR )
	{
		FLA_Symm_internal( side, uplo, alpha, A, B, beta, C, symm_cntl_blk );
	}
	else if ( impl == FLA_TEST_FLAT_UNB_EXT )
	{
		FLA_Symm_external( side, uplo, alpha, A, B, beta, C );
	}
	else
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
	}
}
393 
394