/*

    Copyright (C) 2014, The University of Texas at Austin

    This file is part of libflame and is available under the 3-Clause
    BSD license, which can be found in the LICENSE file at the top-level
    directory, or at http://opensource.org/licenses/BSD-3-Clause

*/
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
14 #define NUM_PARAM_COMBOS 32
15 #define NUM_MATRIX_ARGS  2
16 #define FIRST_VARIANT    1
17 #define LAST_VARIANT     4
18 
19 // Static variables.
20 static char* op_str                   = "Triangular solve with multiple rhs";
21 static char* flash_front_str          = "FLASH_Trsm";
22 static char* fla_front_str            = "FLA_Trsm";
23 static char* fla_unb_var_str          = "unb_var";
24 //static char* fla_opt_var_str          = "opt_var";
25 static char* fla_blk_var_str          = "blk_var";
26 static char* fla_unb_ext_str          = "external";
27 static char* pc_str[NUM_PARAM_COMBOS] = { "llnn", "llnu",
28                                           "llcn", "llcu",
29                                           "lltn", "lltu",
30                                           "llhn", "llhu",
31                                           "lunn", "lunu",
32                                           "lucn", "lucu",
33                                           "lutn", "lutu",
34                                           "luhn", "luhu",
35                                           "rlnn", "rlnu",
36                                           "rlcn", "rlcu",
37                                           "rltn", "rltu",
38                                           "rlhn", "rlhu",
39                                           "runn", "runu",
40                                           "rucn", "rucu",
41                                           "rutn", "rutu",
42                                           "ruhn", "ruhu" };
43 static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
44                                           1e-11, 1e-12,   // warn, pass for d
45                                           1e-02, 1e-03,   // warn, pass for c
46                                           1e-11, 1e-12 }; // warn, pass for z
47 
48 static fla_trsm_t*      trsm_cntl_unb;
49 static fla_trsm_t*      trsm_cntl_blk;
50 static fla_blocksize_t* trsm_cntl_bsize;
51 
52 // Local prototypes.
53 void libfla_test_trsm_experiment( test_params_t params,
54                                   unsigned int  var,
55                                   char*         sc_str,
56                                   FLA_Datatype  datatype,
57                                   unsigned int  p_cur,
58                                   unsigned int  pci,
59                                   unsigned int  n_repeats,
60                                   signed int    impl,
61                                   double*       perf,
62                                   double*       residual );
63 void libfla_test_trsm_impl( int         impl,
64                             FLA_Side    side,
65                             FLA_Uplo    uplo,
66                             FLA_Trans   trans,
67                             FLA_Diag    diag,
68                             FLA_Obj     alpha,
69                             FLA_Obj     A,
70                             FLA_Obj     B );
71 void libfla_test_trsm_cntl_create( unsigned int var,
72                                    dim_t        b_alg_flat );
73 void libfla_test_trsm_cntl_free( void );
74 
75 
// Entry point for the TRSM tests.
//
// Prints the operation banner and then invokes the test driver once for
// every implementation family whose flag in op equals ENABLE. The families
// are visited in a fixed order: FLASH front-end, FLA front-end, unblocked
// variants, blocked variants, external implementation.
//
//   output_stream  not referenced by this function.
//   params         test parameters, forwarded unchanged to the driver.
//   op             per-family enable flags for this operation.
void libfla_test_trsm( FILE* output_stream, test_params_t params, test_op_t op )
{
	// One row per implementation family, in execution order.
	struct
	{
		int        requested;   // nonzero when the family's flag equals ENABLE
		char*      front_str;   // first label string given to the driver
		char*      label_str;   // second label string (NULL for front-ends)
		signed int impl;        // FLA_TEST_* implementation selector
	} families[] =
	{
		{ op.flash_front  == ENABLE, flash_front_str, NULL,            FLA_TEST_HIER_FRONT_END },
		{ op.fla_front    == ENABLE, fla_front_str,   NULL,            FLA_TEST_FLAT_FRONT_END },
		{ op.fla_unb_vars == ENABLE, fla_front_str,   fla_unb_var_str, FLA_TEST_FLAT_UNB_VAR   },
		{ op.fla_blk_vars == ENABLE, fla_front_str,   fla_blk_var_str, FLA_TEST_FLAT_BLK_VAR   },
		{ op.fla_unb_ext  == ENABLE, fla_front_str,   fla_unb_ext_str, FLA_TEST_FLAT_UNB_EXT   }
	};
	unsigned int n_families = sizeof( families ) / sizeof( families[0] );
	unsigned int k;

	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	for ( k = 0; k < n_families; ++k )
	{
		if ( !families[k].requested )
			continue;

		libfla_test_op_driver( families[k].front_str, families[k].label_str,
		                       FIRST_VARIANT, LAST_VARIANT,
		                       NUM_PARAM_COMBOS, pc_str,
		                       NUM_MATRIX_ARGS,
		                       families[k].impl,
		                       params, thresh, libfla_test_trsm_experiment );
	}
}
131 
132 
133 
// Runs one TRSM experiment for a single parameter combination and reports
// the best observed performance and a residual.
//
//   params     supplies b_flash (hierarchical blocksize) and b_alg_flat
//              (blocksize used when building the flat control trees).
//   var        variant number, forwarded to libfla_test_trsm_cntl_create().
//   sc_str     storage characters; sc_str[0] is used for A, sc_str[1] for B.
//   datatype   datatype of all matrices and vectors created here.
//   p_cur      problem size; both m and n are set to p_cur.
//   pci        index into pc_str selecting side/uplo/trans/diag.
//   n_repeats  number of timed repetitions; the minimum time is kept.
//   impl       FLA_TEST_* selector of the implementation under test.
//   perf       [out] flop count / best time / FLOPS_PER_UNIT_PERF.
//   residual   [out] FLA_Max_elemwise_diff() between the two test vectors.
void libfla_test_trsm_experiment( test_params_t params,
                                  unsigned int  var,
                                  char*         sc_str,
                                  FLA_Datatype  datatype,
                                  unsigned int  p_cur,
                                  unsigned int  pci,
                                  unsigned int  n_repeats,
                                  signed int    impl,
                                  double*       perf,
                                  double*       residual )
{
	dim_t        b_flash    = params.b_flash;
	dim_t        b_alg_flat = params.b_alg_flat;
	double       time_min   = 1e9;
	double       time;
	unsigned int i;
	unsigned int m;
	signed int   m_input    = -1;
	unsigned int n;
	signed int   n_input    = -1;
	int          is_hier;
	int          uses_cntl;
	FLA_Side     side;
	FLA_Uplo     uplo;
	FLA_Trans    trans;
	FLA_Diag     diag;
	FLA_Obj      A, B, x, y, z;
	FLA_Obj      alpha;
	FLA_Obj      B_save;
	FLA_Obj      A_test, B_test;

	// Classify the implementation once; both answers are needed for setup
	// and again for teardown.
	is_hier   = ( impl == FLA_TEST_HIER_FRONT_END );
	uses_cntl = ( impl == FLA_TEST_FLAT_UNB_VAR ||
	              impl == FLA_TEST_FLAT_OPT_VAR ||
	              impl == FLA_TEST_FLAT_BLK_VAR ||
	              impl == FLA_TEST_FLAT_UNB_EXT ||
	              impl == FLA_TEST_FLAT_BLK_EXT );

	// Determine the dimensions. A negative *_input acts as a divisor of
	// p_cur; with the value -1 used here both dimensions equal p_cur.
	if ( m_input < 0 ) m = p_cur / abs(m_input);
	else               m = p_cur;
	if ( n_input < 0 ) n = p_cur / abs(n_input);
	else               n = p_cur;

	// Translate parameter characters to libflame constants.
	FLA_Param_map_char_to_flame_side( &pc_str[pci][0], &side );
	FLA_Param_map_char_to_flame_uplo( &pc_str[pci][1], &uplo );
	FLA_Param_map_char_to_flame_trans( &pc_str[pci][2], &trans );
	FLA_Param_map_char_to_flame_diag( &pc_str[pci][3], &diag );

	// Create the matrices for the current operation. A is square and its
	// order matches the dimension of B it is applied to.
	if ( side == FLA_LEFT )
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, m, &A );
	else
		libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n, n, &A );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], m, n, &B );

	// Create vectors for use in the residual test.
	FLA_Obj_create( datatype, n, 1, 0, 0, &x );
	FLA_Obj_create( datatype, m, 1, 0, 0, &y );
	FLA_Obj_create( datatype, m, 1, 0, 0, &z );

	// Initialize the test matrices.
	FLA_Random_tri_matrix( uplo, diag, A );
	FLA_Random_matrix( B );

	// Initialize the test vectors.
	FLA_Random_matrix( x );
	FLA_Set( FLA_ZERO, y );
	FLA_Set( FLA_ZERO, z );

	// Set constants.
	alpha = FLA_TWO;

	// Save the original contents of B; every repetition restarts from it
	// and the residual test needs the unmodified right-hand side.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, B, &B_save );

	// Use hierarchical matrices if we're testing the FLASH front-end;
	// otherwise operate directly on the flat objects.
	if ( is_hier )
	{
		FLASH_Obj_create_hier_copy_of_flat( A, 1, &b_flash, &A_test );
		FLASH_Obj_create_hier_copy_of_flat( B, 1, &b_flash, &B_test );
	}
	else
	{
		A_test = A;
		B_test = B;
	}

	// Create a control tree for the individual variants.
	if ( uses_cntl )
		libfla_test_trsm_cntl_create( var, b_alg_flat );

	// Repeat the experiment n_repeats times and keep the best time.
	for ( i = 0; i < n_repeats; ++i )
	{
		// Restore B before each run (outside the timed region).
		if ( is_hier )
			FLASH_Obj_hierarchify( B_save, B_test );
		else
			FLA_Copy_external( B_save, B_test );

		time = FLA_Clock();

		libfla_test_trsm_impl( impl, side, uplo, trans, diag, alpha, A_test, B_test );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// For the FLASH front-end, copy the solution back into flat matrix B
	// and release the hierarchical copies. In the flat case B_test and B
	// refer to the same object, so nothing needs to be done.
	if ( is_hier )
	{
		FLASH_Obj_flatten( B_test, B );

		FLASH_Obj_free( &A_test );
		FLASH_Obj_free( &B_test );
	}

	// Free the control trees if we're testing the variants.
	if ( uses_cntl )
		libfla_test_trsm_cntl_free();

	// Compute the performance of the best experiment repeat. The flop
	// count is formed in double precision: as a product of unsigned ints
	// it would wrap once p_cur exceeds about 1625 (32-bit unsigned int).
	if ( side == FLA_LEFT )
		*perf = ( 1.0 * m * m * n ) / time_min / FLOPS_PER_UNIT_PERF;
	else
		*perf = ( 1.0 * m * n * n ) / time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

	// Compute:
	//   y = B * x
	// and compare to
	//   z = ( alpha * inv(A) * B ) x      (side = left)
	//   z = ( alpha * B * inv(A) ) x      (side = right)
	// where B on the first line is the computed solution and B on the
	// z lines is the saved original (B_save).
	FLA_Gemv_external( FLA_NO_TRANSPOSE, FLA_ONE, B, x, FLA_ZERO, y );

	if ( side == FLA_LEFT )
	{
		FLA_Gemv_external( FLA_NO_TRANSPOSE, alpha, B_save, x, FLA_ZERO, z );
		FLA_Trsv_external( uplo, trans, diag, A, z );
	}
	else
	{
		// x is overwritten here; y was already computed from the old x.
		FLA_Trsv_external( uplo, trans, diag, A, x );
		FLA_Gemv_external( FLA_NO_TRANSPOSE, alpha, B_save, x, FLA_ZERO, z );
	}

	// Report the largest element-wise difference between y and z.
	*residual = FLA_Max_elemwise_diff( y, z );

	// Free the supporting flat objects.
	FLA_Obj_free( &B_save );

	// Free the flat test matrices and vectors.
	FLA_Obj_free( &A );
	FLA_Obj_free( &B );
	FLA_Obj_free( &x );
	FLA_Obj_free( &y );
	FLA_Obj_free( &z );
}
308 
309 
310 
311 extern fla_scal_t* fla_scal_cntl_blas;
312 extern fla_gemm_t* fla_gemm_cntl_blas;
313 extern fla_trsm_t* fla_trsm_cntl_blas;
314 
libfla_test_trsm_cntl_create(unsigned int var,dim_t b_alg_flat)315 void libfla_test_trsm_cntl_create( unsigned int var,
316                                    dim_t        b_alg_flat )
317 {
318 	int var_unb = FLA_UNB_VAR_OFFSET + var;
319 	int var_blk = FLA_BLK_VAR_OFFSET + var;
320 
321 	trsm_cntl_bsize = FLA_Blocksize_create( b_alg_flat,
322 	                                        b_alg_flat,
323 	                                        b_alg_flat,
324 	                                        b_alg_flat );
325 
326 	trsm_cntl_unb   = FLA_Cntl_trsm_obj_create( FLA_FLAT,
327 	                                            var_unb,
328 	                                            NULL,
329 	                                            NULL,
330 	                                            NULL,
331 	                                            NULL );
332 
333 	trsm_cntl_blk   = FLA_Cntl_trsm_obj_create( FLA_FLAT,
334 	                                            var_blk,
335 	                                            trsm_cntl_bsize,
336 	                                            fla_scal_cntl_blas,
337 	                                            fla_trsm_cntl_blas,
338 	                                            fla_gemm_cntl_blas );
339 }
340 
341 
342 
libfla_test_trsm_cntl_free(void)343 void libfla_test_trsm_cntl_free( void )
344 {
345 	FLA_Blocksize_free( trsm_cntl_bsize );
346 
347 	FLA_Cntl_obj_free( trsm_cntl_unb );
348 	FLA_Cntl_obj_free( trsm_cntl_blk );
349 }
350 
351 
352 
// Performs one TRSM call using the implementation selected by impl.
//
//   impl                     FLA_TEST_* selector; see the branches below.
//   side, uplo, trans, diag  TRSM parameters, forwarded unchanged.
//   alpha, A, B              TRSM operands, forwarded unchanged.
//
// The variant branches use the file-scope control trees, which must have
// been built by libfla_test_trsm_cntl_create() beforehand. Any selector not
// handled below is reported through libfla_test_output_error().
void libfla_test_trsm_impl( int       impl,
                            FLA_Side  side,
                            FLA_Uplo  uplo,
                            FLA_Trans trans,
                            FLA_Diag  diag,
                            FLA_Obj   alpha,
                            FLA_Obj   A,
                            FLA_Obj   B )
{
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		// Hierarchical (FLASH) front-end.
		FLASH_Trsm( side, uplo, trans, diag, alpha, A, B );
	}
	else if ( impl == FLA_TEST_FLAT_FRONT_END )
	{
		// Flat front-end.
		FLA_Trsm( side, uplo, trans, diag, alpha, A, B );
	}
	else if ( impl == FLA_TEST_FLAT_UNB_VAR )
	{
		// Unblocked variant, selected through its control tree.
		FLA_Trsm_internal( side, uplo, trans, diag, alpha, A, B, trsm_cntl_unb );
	}
	else if ( impl == FLA_TEST_FLAT_BLK_VAR )
	{
		// Blocked variant, selected through its control tree.
		FLA_Trsm_internal( side, uplo, trans, diag, alpha, A, B, trsm_cntl_blk );
	}
	else if ( impl == FLA_TEST_FLAT_UNB_EXT )
	{
		// External implementation.
		FLA_Trsm_external( side, uplo, trans, diag, alpha, A, B );
	}
	else
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
	}
}
388 
389