1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
// Number of uplo/diag parameter combinations (the entries of pc_str below).
#define NUM_PARAM_COMBOS 4
// Matrix-argument count handed to libfla_test_op_driver(); the experiment
// creates a single test matrix, A.
#define NUM_MATRIX_ARGS  1
// Inclusive range of variant numbers handed to libfla_test_op_driver().
#define FIRST_VARIANT    1
#define LAST_VARIANT     4

// Static variables.
// Labels used in the test output for this operation and its implementations.
static char* op_str                   = "Triangular inversion";
static char* flash_front_str          = "FLASH_Trinv";
static char* fla_front_str            = "FLA_Trinv";
static char* fla_unb_var_str          = "unb_var";
static char* fla_opt_var_str          = "opt_var";
static char* fla_blk_var_str          = "blk_var";
// Parameter combinations: character 0 is mapped to an FLA_Uplo value and
// character 1 to an FLA_Diag value in libfla_test_trinv_experiment().
static char* pc_str[NUM_PARAM_COMBOS] = { "ln", "lu", "un", "uu" };
// Residual thresholds handed to the test driver, one warn/pass pair per datatype.
static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
                                          1e-11, 1e-12,   // warn, pass for d
                                          1e-02, 1e-03,   // warn, pass for c
                                          1e-11, 1e-12 }; // warn, pass for z

// Control trees and blocksize used when testing individual variants.
// They are created by libfla_test_trinv_cntl_create() and released by
// libfla_test_trinv_cntl_free().
static fla_trinv_t*     trinv_cntl_opt;
static fla_trinv_t*     trinv_cntl_unb;
static fla_trinv_t*     trinv_cntl_blk;
static fla_blocksize_t* trinv_cntl_bsize;

// Local prototypes.
// Runs one timed experiment and reports its performance and residual.
void libfla_test_trinv_experiment( test_params_t params,
                                   unsigned int  var,
                                   char*         sc_str,
                                   FLA_Datatype  datatype,
                                   unsigned int  p_cur,
                                   unsigned int  pci,
                                   unsigned int  n_repeats,
                                   signed int    impl,
                                   double*       perf,
                                   double*       residual );
// Dispatches to the Trinv implementation selected by impl.
void libfla_test_trinv_impl( int         impl,
                             FLA_Uplo    uplo,
                             FLA_Diag    diag,
                             FLA_Obj     A );
// Builds / frees the variant control trees stored in the statics above.
void libfla_test_trinv_cntl_create( unsigned int var,
                                    dim_t        b_alg_flat );
void libfla_test_trinv_cntl_free( void );
55 
56 
/*
   Forward one Trinv test suite to the generic operation driver.

   front_str - label of the front-end under test (FLASH_Trinv / FLA_Trinv).
   var_str   - variant-family label, or NULL when testing a front-end.
   impl_type - FLA_TEST_* code identifying the implementation to exercise.
   params    - global test parameters, passed through unchanged.
*/
static void libfla_test_trinv_drive( char*         front_str,
                                     char*         var_str,
                                     signed int    impl_type,
                                     test_params_t params )
{
	libfla_test_op_driver( front_str, var_str,
	                       FIRST_VARIANT, LAST_VARIANT,
	                       NUM_PARAM_COMBOS, pc_str,
	                       NUM_MATRIX_ARGS,
	                       impl_type,
	                       params, thresh, libfla_test_trinv_experiment );
}

/*
   Entry point for the triangular-inversion tests.

   Prints a header, then runs each implementation family whose flag in
   'op' equals ENABLE. The order is fixed: the FLASH front-end, the FLA
   front-end, then the unblocked, optimized, and blocked variants.
   'output_stream' is not referenced by this routine.
*/
void libfla_test_trinv( FILE* output_stream, test_params_t params, test_op_t op )
{
	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	// Front-end tests (no variant label).
	if ( op.flash_front == ENABLE )
	{
		libfla_test_trinv_drive( flash_front_str, NULL,
		                         FLA_TEST_HIER_FRONT_END, params );
	}
	if ( op.fla_front == ENABLE )
	{
		libfla_test_trinv_drive( fla_front_str, NULL,
		                         FLA_TEST_FLAT_FRONT_END, params );
	}

	// Individual-variant tests.
	if ( op.fla_unb_vars == ENABLE )
	{
		libfla_test_trinv_drive( fla_front_str, fla_unb_var_str,
		                         FLA_TEST_FLAT_UNB_VAR, params );
	}
	if ( op.fla_opt_vars == ENABLE )
	{
		libfla_test_trinv_drive( fla_front_str, fla_opt_var_str,
		                         FLA_TEST_FLAT_OPT_VAR, params );
	}
	if ( op.fla_blk_vars == ENABLE )
	{
		libfla_test_trinv_drive( fla_front_str, fla_blk_var_str,
		                         FLA_TEST_FLAT_BLK_VAR, params );
	}
}
112 
113 
114 
/*
   Run one triangular-inversion experiment.

   A random m x m triangular matrix A is created. Its uplo and diag
   properties come from pc_str[pci]. A is inverted in place n_repeats
   times via libfla_test_trinv_impl(), and the fastest repetition is used
   to compute the reported performance. The result is then checked by
   forming x = A_inv * b for a random right-hand side b and measuring
   || A * x - b ||_2 with the saved original A.

   params    - test parameters; b_flash and b_alg_flat are read here.
   var       - variant number; used only to build the variant control trees.
   sc_str    - only its first character is used, passed to libfla_test_obj_create().
   datatype  - datatype of all test objects.
   p_cur     - current problem size (m = p_cur here).
   pci       - index into pc_str selecting uplo and diag.
   n_repeats - number of timed repetitions.
   impl      - FLA_TEST_* code for the implementation under test.
   perf      - output: performance of the fastest repetition, scaled by
               FLOPS_PER_UNIT_PERF.
   residual  - output: 2-norm of A * x - b.
*/
void libfla_test_trinv_experiment( test_params_t params,
                                   unsigned int  var,
                                   char*         sc_str,
                                   FLA_Datatype  datatype,
                                   unsigned int  p_cur,
                                   unsigned int  pci,
                                   unsigned int  n_repeats,
                                   signed int    impl,
                                   double*       perf,
                                   double*       residual )
{
	dim_t        b_flash    = params.b_flash;
	dim_t        b_alg_flat = params.b_alg_flat;
	double       time_min   = 1e9;
	double       time;
	unsigned int i;
	unsigned int m;
	signed int   m_input    = -1;
	FLA_Uplo     uplo;
	FLA_Diag     diag;
	FLA_Obj      A, x, b, norm;
	FLA_Obj      A_save;
	FLA_Obj      A_test, x_test, b_test;

	// Determine the dimensions. A negative m_input means "p_cur divided by
	// |m_input|"; with m_input == -1 this is simply m = p_cur.
	if ( m_input < 0 ) m = p_cur / abs(m_input);
	else               m = p_cur;

	// Translate parameter characters to libflame constants.
	FLA_Param_map_char_to_flame_uplo( &pc_str[pci][0], &uplo );
	FLA_Param_map_char_to_flame_diag( &pc_str[pci][1], &diag );

	// Create the matrices for the current operation.
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, m, &A );

	// Initialize the test matrices.
	FLA_Random_tri_matrix( uplo, diag, A );

	// Save the original object contents in a temporary object. It is used
	// to restore A before each repetition and to compute the residual.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, A, &A_save );

	// Create vectors to form a linear system.
	FLA_Obj_create( datatype, m, 1, 0, 0, &x );
	FLA_Obj_create( datatype, m, 1, 0, 0, &b );

	// Create a real scalar object to hold the residual norm.
	FLA_Obj_create( FLA_Obj_datatype_proj_to_real( A ), 1, 1, 0, 0, &norm );

	// Create a random right-hand side vector.
	FLA_Random_matrix( b );

	// Use hierarchical matrices if we're testing the FLASH front-end;
	// otherwise operate directly on the flat matrix A.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Obj_create_hier_copy_of_flat( A, 1, &b_flash, &A_test );
		FLASH_Obj_create_hier_copy_of_flat( b, 1, &b_flash, &b_test );
		FLASH_Obj_create_hier_copy_of_flat( x, 1, &b_flash, &x_test );
	}
	else
	{
		A_test = A;
	}

	// Create a control tree for the individual variants.
	if ( impl == FLA_TEST_FLAT_UNB_VAR ||
	     impl == FLA_TEST_FLAT_OPT_VAR ||
	     impl == FLA_TEST_FLAT_BLK_VAR )
		libfla_test_trinv_cntl_create( var, b_alg_flat );

	// Repeat the experiment n_repeats times and record the fastest run.
	// A_test is restored from A_save first because the inversion is in place.
	for ( i = 0; i < n_repeats; ++i )
	{
		if ( impl == FLA_TEST_HIER_FRONT_END )
			FLASH_Obj_hierarchify( A_save, A_test );
		else
			FLA_Copy_external( A_save, A_test );

		time = FLA_Clock();

		libfla_test_trinv_impl( impl, uplo, diag, A_test );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Perform a linear solve with the result: x := A_test * b, where
	// A_test now holds the computed inverse of A.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Copy( b_test, x_test );
		FLASH_Trmm( FLA_LEFT, uplo, FLA_NO_TRANSPOSE, diag,
		            FLA_ONE, A_test, x_test );
		FLASH_Obj_flatten( x_test, x );
	}
	else
	{
		FLA_Copy_external( b, x );
		FLA_Trmm_external( FLA_LEFT, uplo, FLA_NO_TRANSPOSE, diag,
		                   FLA_ONE, A_test, x );
	}

	// Free the hierarchical matrices if we're testing the FLASH front-end.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Obj_free( &A_test );
		FLASH_Obj_free( &b_test );
		FLASH_Obj_free( &x_test );
	}

	// Free the control trees if we're testing the variants.
	if ( impl == FLA_TEST_FLAT_UNB_VAR ||
	     impl == FLA_TEST_FLAT_OPT_VAR ||
	     impl == FLA_TEST_FLAT_BLK_VAR )
		libfla_test_trinv_cntl_free();

	// Compute the performance of the best experiment repeat. Inverting an
	// m x m triangular matrix costs about m^3 / 3 flops (the xTRTRI count in
	// LAWN 41); complex arithmetic is counted as four times the real cost.
	*perf = 1.0 / 3.0 * m * m * m / time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

	// Compute the residual: b := A_save * x - b, then residual = || b ||_2.
	FLA_Trmvsx_external( uplo, FLA_NO_TRANSPOSE, diag,
	                     FLA_ONE, A_save, x, FLA_MINUS_ONE, b );
	FLA_Nrm2_external( b, norm );
	FLA_Obj_extract_real_scalar( norm, residual );

	// Free the supporting flat objects.
	FLA_Obj_free( &x );
	FLA_Obj_free( &b );
	FLA_Obj_free( &norm );
	FLA_Obj_free( &A_save );

	// Free the flat test matrices.
	FLA_Obj_free( &A );
}
248 
249 
250 
251 extern fla_trmm_t* fla_trmm_cntl_blas;
252 extern fla_trsm_t* fla_trsm_cntl_blas;
253 extern fla_gemm_t* fla_gemm_cntl_blas;
254 
libfla_test_trinv_cntl_create(unsigned int var,dim_t b_alg_flat)255 void libfla_test_trinv_cntl_create( unsigned int var,
256                                     dim_t        b_alg_flat )
257 {
258 	int var_unb = FLA_UNB_VAR_OFFSET + var;
259 	int var_opt = FLA_OPT_VAR_OFFSET + var;
260 	int var_blk = FLA_BLK_VAR_OFFSET + var;
261 
262 	trinv_cntl_bsize = FLA_Blocksize_create( b_alg_flat, b_alg_flat, b_alg_flat, b_alg_flat );
263 
264 	trinv_cntl_unb   = FLA_Cntl_trinv_obj_create( FLA_FLAT,
265 	                                              var_unb,
266 	                                              NULL,
267 	                                              NULL,
268 	                                              NULL,
269 	                                              NULL,
270 	                                              NULL,
271 	                                              NULL );
272 
273 	trinv_cntl_opt   = FLA_Cntl_trinv_obj_create( FLA_FLAT,
274 	                                              var_opt,
275 	                                              NULL,
276 	                                              NULL,
277 	                                              NULL,
278 	                                              NULL,
279 	                                              NULL,
280 	                                              NULL );
281 
282 	trinv_cntl_blk   = FLA_Cntl_trinv_obj_create( FLA_FLAT,
283 	                                              var_blk,
284 	                                              trinv_cntl_bsize,
285 	                                              trinv_cntl_opt,
286 	                                              fla_trmm_cntl_blas,
287 	                                              fla_trsm_cntl_blas,
288 	                                              fla_trsm_cntl_blas,
289 	                                              fla_gemm_cntl_blas );
290 }
291 
292 
293 
libfla_test_trinv_cntl_free(void)294 void libfla_test_trinv_cntl_free( void )
295 {
296 	FLA_Blocksize_free( trinv_cntl_bsize );
297 
298 	FLA_Cntl_obj_free( trinv_cntl_unb );
299 	FLA_Cntl_obj_free( trinv_cntl_opt );
300 	FLA_Cntl_obj_free( trinv_cntl_blk );
301 }
302 
303 
304 
/*
   Apply the Trinv implementation selected by 'impl' to A, in place.

   impl - one of FLA_TEST_HIER_FRONT_END, FLA_TEST_FLAT_FRONT_END,
          FLA_TEST_FLAT_UNB_VAR, FLA_TEST_FLAT_OPT_VAR, FLA_TEST_FLAT_BLK_VAR.
   uplo - which triangle of A is referenced.
   diag - whether A is treated as unit or non-unit diagonal.
   A    - the matrix, overwritten with the result.

   The variant cases use the file-scope control trees, so
   libfla_test_trinv_cntl_create() must have been called first. Any other
   'impl' value reports an error and leaves A untouched.
*/
void libfla_test_trinv_impl( int impl,
                             FLA_Uplo uplo,
                             FLA_Diag diag,
                             FLA_Obj A )
{
	fla_trinv_t* cntl;

	// Front-ends manage their own control trees.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLASH_Trinv( uplo, diag, A );
		return;
	}
	if ( impl == FLA_TEST_FLAT_FRONT_END )
	{
		FLA_Trinv( uplo, diag, A );
		return;
	}

	// Individual variants: choose the matching control tree.
	if      ( impl == FLA_TEST_FLAT_UNB_VAR ) cntl = trinv_cntl_unb;
	else if ( impl == FLA_TEST_FLAT_OPT_VAR ) cntl = trinv_cntl_opt;
	else if ( impl == FLA_TEST_FLAT_BLK_VAR ) cntl = trinv_cntl_blk;
	else
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
		return;
	}

	FLA_Trinv_internal( uplo, diag, A, cntl );
}
336 
337