1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12 #include "test_libflame.h"
13
14 #define NUM_PARAM_COMBOS 4
15 #define NUM_MATRIX_ARGS 1
16 #define FIRST_VARIANT 1
17 #define LAST_VARIANT 4
18
19 // Static variables.
20 static char* op_str = "Triangular inversion";
21 static char* flash_front_str = "FLASH_Trinv";
22 static char* fla_front_str = "FLA_Trinv";
23 static char* fla_unb_var_str = "unb_var";
24 static char* fla_opt_var_str = "opt_var";
25 static char* fla_blk_var_str = "blk_var";
26 static char* pc_str[NUM_PARAM_COMBOS] = { "ln", "lu", "un", "uu" };
27 static test_thresh_t thresh = { 1e-02, 1e-03, // warn, pass for s
28 1e-11, 1e-12, // warn, pass for d
29 1e-02, 1e-03, // warn, pass for c
30 1e-11, 1e-12 }; // warn, pass for z
31
32 static fla_trinv_t* trinv_cntl_opt;
33 static fla_trinv_t* trinv_cntl_unb;
34 static fla_trinv_t* trinv_cntl_blk;
35 static fla_blocksize_t* trinv_cntl_bsize;
36
37 // Local prototypes.
38 void libfla_test_trinv_experiment( test_params_t params,
39 unsigned int var,
40 char* sc_str,
41 FLA_Datatype datatype,
42 unsigned int p_cur,
43 unsigned int pci,
44 unsigned int n_repeats,
45 signed int impl,
46 double* perf,
47 double* residual );
48 void libfla_test_trinv_impl( int impl,
49 FLA_Uplo uplo,
50 FLA_Diag diag,
51 FLA_Obj A );
52 void libfla_test_trinv_cntl_create( unsigned int var,
53 dim_t b_alg_flat );
54 void libfla_test_trinv_cntl_free( void );
55
56
// Helper: hands one implementation class to the generic test driver, but only
// if that class is enabled. A NULL impl_var_str is forwarded unchanged.
static void libfla_test_trinv_drive_if( int           enabled,
                                        char*         func_str,
                                        char*         impl_var_str,
                                        signed int    impl,
                                        test_params_t params )
{
    if ( !enabled ) return;

    libfla_test_op_driver( func_str, impl_var_str,
                           FIRST_VARIANT, LAST_VARIANT,
                           NUM_PARAM_COMBOS, pc_str,
                           NUM_MATRIX_ARGS,
                           impl,
                           params, thresh, libfla_test_trinv_experiment );
}

// Entry point of the triangular-inversion test module.
//
// Prints the operation banner and then, for every implementation class whose
// flag in op equals ENABLE, runs the test driver with
// libfla_test_trinv_experiment as the experiment callback. The classes are
// visited in this order: FLASH front-end, FLA front-end, unblocked variants,
// optimized variants, blocked variants.
//
//   output_stream  not referenced by this function.
//   params         test parameters, forwarded to the driver.
//   op             per-class enable flags for this operation.
void libfla_test_trinv( FILE* output_stream, test_params_t params, test_op_t op )
{
    libfla_test_output_info( "--- %s ---\n", op_str );
    libfla_test_output_info( "\n" );

    libfla_test_trinv_drive_if( op.flash_front == ENABLE,
                                flash_front_str, NULL,
                                FLA_TEST_HIER_FRONT_END, params );

    libfla_test_trinv_drive_if( op.fla_front == ENABLE,
                                fla_front_str, NULL,
                                FLA_TEST_FLAT_FRONT_END, params );

    libfla_test_trinv_drive_if( op.fla_unb_vars == ENABLE,
                                fla_front_str, fla_unb_var_str,
                                FLA_TEST_FLAT_UNB_VAR, params );

    libfla_test_trinv_drive_if( op.fla_opt_vars == ENABLE,
                                fla_front_str, fla_opt_var_str,
                                FLA_TEST_FLAT_OPT_VAR, params );

    libfla_test_trinv_drive_if( op.fla_blk_vars == ENABLE,
                                fla_front_str, fla_blk_var_str,
                                FLA_TEST_FLAT_BLK_VAR, params );
}
112
113
114
// Runs one triangular-inversion experiment and reports timing and accuracy.
//
//   params     supplies b_flash (hierarchical blocksize) and b_alg_flat
//              (blocksize given to the variant control trees).
//   var        variant number, forwarded to libfla_test_trinv_cntl_create().
//   sc_str     its first character is forwarded to libfla_test_obj_create().
//   datatype   datatype of the matrix and vectors that are created.
//   p_cur      current problem size; the matrix is m x m with m derived from it.
//   pci        index into pc_str selecting the uplo/diag combination.
//   n_repeats  number of timed repetitions; the minimum time is kept.
//   impl       implementation class under test (FLA_TEST_* constant).
//   perf       out: performance figure computed from the best time.
//   residual   out: norm extracted after the residual computation below.
void libfla_test_trinv_experiment( test_params_t params,
                                   unsigned int  var,
                                   char*         sc_str,
                                   FLA_Datatype  datatype,
                                   unsigned int  p_cur,
                                   unsigned int  pci,
                                   unsigned int  n_repeats,
                                   signed int    impl,
                                   double*       perf,
                                   double*       residual )
{
    // Classify the implementation once; both flags are consulted repeatedly.
    const int    is_hier    = ( impl == FLA_TEST_HIER_FRONT_END );
    const int    is_variant = ( impl == FLA_TEST_FLAT_UNB_VAR ||
                                impl == FLA_TEST_FLAT_OPT_VAR ||
                                impl == FLA_TEST_FLAT_BLK_VAR );

    dim_t        blk_hier   = params.b_flash;
    dim_t        blk_flat   = params.b_alg_flat;
    signed int   m_input    = -1;
    double       best_time  = 1e9;
    double       elapsed;
    unsigned int rep;
    unsigned int m;
    FLA_Uplo     uplo;
    FLA_Diag     diag;
    FLA_Obj      A, x, b, norm;
    FLA_Obj      A_orig;
    FLA_Obj      A_test, x_test, b_test;

    // A negative m_input scales the problem size down by its magnitude;
    // otherwise the problem size is used directly.
    m = ( m_input < 0 ) ? p_cur / abs( m_input ) : p_cur;

    // Map the two parameter characters to libflame uplo/diag constants.
    FLA_Param_map_char_to_flame_uplo( &pc_str[pci][0], &uplo );
    FLA_Param_map_char_to_flame_diag( &pc_str[pci][1], &diag );

    // Create the m x m matrix and fill it as a random triangular matrix.
    libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, m, &A );
    FLA_Random_tri_matrix( uplo, diag, A );

    // Keep a pristine copy so that every repetition starts from the same data.
    FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, A, &A_orig );

    // Vectors x and b (m x 1) used to check the result via a linear system.
    FLA_Obj_create( datatype, m, 1, 0, 0, &x );
    FLA_Obj_create( datatype, m, 1, 0, 0, &b );

    // Real 1 x 1 object that receives the norm.
    FLA_Obj_create( FLA_Obj_datatype_proj_to_real( A ), 1, 1, 0, 0, &norm );

    // Random right-hand side.
    FLA_Random_matrix( b );

    // The FLASH front-end operates on hierarchical copies; every other
    // implementation works directly on the flat matrix.
    if ( is_hier )
    {
        FLASH_Obj_create_hier_copy_of_flat( A, 1, &blk_hier, &A_test );
        FLASH_Obj_create_hier_copy_of_flat( b, 1, &blk_hier, &b_test );
        FLASH_Obj_create_hier_copy_of_flat( x, 1, &blk_hier, &x_test );
    }
    else
    {
        A_test = A;
    }

    // Individual variants need their control trees.
    if ( is_variant )
    {
        libfla_test_trinv_cntl_create( var, blk_flat );
    }

    // Timed repetitions; only the fastest one is kept.
    for ( rep = 0; rep < n_repeats; ++rep )
    {
        // Restore the original contents before each run.
        if ( is_hier ) FLASH_Obj_hierarchify( A_orig, A_test );
        else           FLA_Copy_external( A_orig, A_test );

        elapsed = FLA_Clock();
        libfla_test_trinv_impl( impl, uplo, diag, A_test );
        elapsed = FLA_Clock() - elapsed;

        best_time = min( best_time, elapsed );
    }

    // Apply the computed result to b, leaving the product in x. For the
    // FLASH front-end the hierarchical objects are released right after
    // x has been flattened.
    if ( is_hier )
    {
        FLASH_Copy( b_test, x_test );
        FLASH_Trmm( FLA_LEFT, uplo, FLA_NO_TRANSPOSE, diag,
                    FLA_ONE, A_test, x_test );
        FLASH_Obj_flatten( x_test, x );

        FLASH_Obj_free( &A_test );
        FLASH_Obj_free( &b_test );
        FLASH_Obj_free( &x_test );
    }
    else
    {
        FLA_Copy_external( b, x );
        FLA_Trmm_external( FLA_LEFT, uplo, FLA_NO_TRANSPOSE, diag,
                           FLA_ONE, A_test, x );
    }

    // Release the control trees created for the variants.
    if ( is_variant )
    {
        libfla_test_trinv_cntl_free();
    }

    // Performance of the best repetition; scaled by 4 for complex datatypes.
    *perf = 1.0 / 4.0 * m * m * m / best_time / FLOPS_PER_UNIT_PERF;
    if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

    // Residual: combine the original matrix, x and b, then take the norm of b.
    // NOTE(review): presumably this forms b := A_orig * x - b -- confirm
    // against the FLA_Trmvsx_external documentation.
    FLA_Trmvsx_external( uplo, FLA_NO_TRANSPOSE, diag,
                         FLA_ONE, A_orig, x, FLA_MINUS_ONE, b );
    FLA_Nrm2_external( b, norm );
    FLA_Obj_extract_real_scalar( norm, residual );

    // Release the supporting flat objects, then the flat test matrix.
    FLA_Obj_free( &x );
    FLA_Obj_free( &b );
    FLA_Obj_free( &norm );
    FLA_Obj_free( &A_orig );

    FLA_Obj_free( &A );
}
248
249
250
251 extern fla_trmm_t* fla_trmm_cntl_blas;
252 extern fla_trsm_t* fla_trsm_cntl_blas;
253 extern fla_gemm_t* fla_gemm_cntl_blas;
254
libfla_test_trinv_cntl_create(unsigned int var,dim_t b_alg_flat)255 void libfla_test_trinv_cntl_create( unsigned int var,
256 dim_t b_alg_flat )
257 {
258 int var_unb = FLA_UNB_VAR_OFFSET + var;
259 int var_opt = FLA_OPT_VAR_OFFSET + var;
260 int var_blk = FLA_BLK_VAR_OFFSET + var;
261
262 trinv_cntl_bsize = FLA_Blocksize_create( b_alg_flat, b_alg_flat, b_alg_flat, b_alg_flat );
263
264 trinv_cntl_unb = FLA_Cntl_trinv_obj_create( FLA_FLAT,
265 var_unb,
266 NULL,
267 NULL,
268 NULL,
269 NULL,
270 NULL,
271 NULL );
272
273 trinv_cntl_opt = FLA_Cntl_trinv_obj_create( FLA_FLAT,
274 var_opt,
275 NULL,
276 NULL,
277 NULL,
278 NULL,
279 NULL,
280 NULL );
281
282 trinv_cntl_blk = FLA_Cntl_trinv_obj_create( FLA_FLAT,
283 var_blk,
284 trinv_cntl_bsize,
285 trinv_cntl_opt,
286 fla_trmm_cntl_blas,
287 fla_trsm_cntl_blas,
288 fla_trsm_cntl_blas,
289 fla_gemm_cntl_blas );
290 }
291
292
293
libfla_test_trinv_cntl_free(void)294 void libfla_test_trinv_cntl_free( void )
295 {
296 FLA_Blocksize_free( trinv_cntl_bsize );
297
298 FLA_Cntl_obj_free( trinv_cntl_unb );
299 FLA_Cntl_obj_free( trinv_cntl_opt );
300 FLA_Cntl_obj_free( trinv_cntl_blk );
301 }
302
303
304
// Dispatches to the triangular-inversion routine selected by impl.
//
//   impl  one of the FLA_TEST_* implementation constants; any other value is
//         reported through libfla_test_output_error().
//   uplo  uplo parameter forwarded to the selected routine.
//   diag  diag parameter forwarded to the selected routine.
//   A     object passed to the selected routine (hierarchical for the FLASH
//         front-end as set up by the caller in this file).
void libfla_test_trinv_impl( int      impl,
                             FLA_Uplo uplo,
                             FLA_Diag diag,
                             FLA_Obj  A )
{
    if ( impl == FLA_TEST_HIER_FRONT_END )
    {
        FLASH_Trinv( uplo, diag, A );
    }
    else if ( impl == FLA_TEST_FLAT_FRONT_END )
    {
        FLA_Trinv( uplo, diag, A );
    }
    else if ( impl == FLA_TEST_FLAT_UNB_VAR )
    {
        // Variants go through the internal routine with their control tree.
        FLA_Trinv_internal( uplo, diag, A, trinv_cntl_unb );
    }
    else if ( impl == FLA_TEST_FLAT_OPT_VAR )
    {
        FLA_Trinv_internal( uplo, diag, A, trinv_cntl_opt );
    }
    else if ( impl == FLA_TEST_FLAT_BLK_VAR )
    {
        FLA_Trinv_internal( uplo, diag, A, trinv_cntl_blk );
    }
    else
    {
        libfla_test_output_error( "Invalid implementation type.\n" );
    }
}
336
337