1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12 #include "test_libflame.h"
13
14 #define NUM_PARAM_COMBOS 32
15 #define NUM_MATRIX_ARGS 2
16 #define FIRST_VARIANT 1
17 #define LAST_VARIANT 4
18
19 // Static variables.
20 static char* op_str = "Triangular solve with multiple rhs";
21 static char* flash_front_str = "FLASH_Trsm";
22 static char* fla_front_str = "FLA_Trsm";
23 static char* fla_unb_var_str = "unb_var";
24 //static char* fla_opt_var_str = "opt_var";
25 static char* fla_blk_var_str = "blk_var";
26 static char* fla_unb_ext_str = "external";
27 static char* pc_str[NUM_PARAM_COMBOS] = { "llnn", "llnu",
28 "llcn", "llcu",
29 "lltn", "lltu",
30 "llhn", "llhu",
31 "lunn", "lunu",
32 "lucn", "lucu",
33 "lutn", "lutu",
34 "luhn", "luhu",
35 "rlnn", "rlnu",
36 "rlcn", "rlcu",
37 "rltn", "rltu",
38 "rlhn", "rlhu",
39 "runn", "runu",
40 "rucn", "rucu",
41 "rutn", "rutu",
42 "ruhn", "ruhu" };
43 static test_thresh_t thresh = { 1e-02, 1e-03, // warn, pass for s
44 1e-11, 1e-12, // warn, pass for d
45 1e-02, 1e-03, // warn, pass for c
46 1e-11, 1e-12 }; // warn, pass for z
47
48 static fla_trsm_t* trsm_cntl_unb;
49 static fla_trsm_t* trsm_cntl_blk;
50 static fla_blocksize_t* trsm_cntl_bsize;
51
52 // Local prototypes.
53 void libfla_test_trsm_experiment( test_params_t params,
54 unsigned int var,
55 char* sc_str,
56 FLA_Datatype datatype,
57 unsigned int p_cur,
58 unsigned int pci,
59 unsigned int n_repeats,
60 signed int impl,
61 double* perf,
62 double* residual );
63 void libfla_test_trsm_impl( int impl,
64 FLA_Side side,
65 FLA_Uplo uplo,
66 FLA_Trans trans,
67 FLA_Diag diag,
68 FLA_Obj alpha,
69 FLA_Obj A,
70 FLA_Obj B );
71 void libfla_test_trsm_cntl_create( unsigned int var,
72 dim_t b_alg_flat );
73 void libfla_test_trsm_cntl_free( void );
74
75
/*
 * Hand one implementation family to the generic op driver, supplying the
 * trsm-specific variant range, parameter combinations, thresholds and
 * experiment callback. Only the front-end label, variant label and
 * implementation selector differ between the families.
 */
static void libfla_test_trsm_run_driver( char*         front_str,
                                         char*         var_str,
                                         signed int    impl,
                                         test_params_t params )
{
    libfla_test_op_driver( front_str, var_str,
                           FIRST_VARIANT, LAST_VARIANT,
                           NUM_PARAM_COMBOS, pc_str,
                           NUM_MATRIX_ARGS,
                           impl,
                           params, thresh, libfla_test_trsm_experiment );
}

/*
 * Entry point for the trsm tests. Prints a header, then runs each
 * implementation family that is enabled in op, in this fixed order:
 * FLASH front-end, FLA front-end, unblocked variants, blocked variants,
 * external implementation.
 *
 * output_stream is part of the interface but is not used here.
 */
void libfla_test_trsm( FILE* output_stream, test_params_t params, test_op_t op )
{
    libfla_test_output_info( "--- %s ---\n", op_str );
    libfla_test_output_info( "\n" );

    if ( op.flash_front == ENABLE )
        libfla_test_trsm_run_driver( flash_front_str, NULL,
                                     FLA_TEST_HIER_FRONT_END, params );

    if ( op.fla_front == ENABLE )
        libfla_test_trsm_run_driver( fla_front_str, NULL,
                                     FLA_TEST_FLAT_FRONT_END, params );

    if ( op.fla_unb_vars == ENABLE )
        libfla_test_trsm_run_driver( fla_front_str, fla_unb_var_str,
                                     FLA_TEST_FLAT_UNB_VAR, params );

    if ( op.fla_blk_vars == ENABLE )
        libfla_test_trsm_run_driver( fla_front_str, fla_blk_var_str,
                                     FLA_TEST_FLAT_BLK_VAR, params );

    if ( op.fla_unb_ext == ENABLE )
        libfla_test_trsm_run_driver( fla_front_str, fla_unb_ext_str,
                                     FLA_TEST_FLAT_UNB_EXT, params );
}
131
132
133
/*
 * Run one trsm experiment. The experiment:
 *   - builds a random triangular A and a general m x n matrix B;
 *   - times n_repeats invocations of the implementation selected by impl;
 *   - checks the result with a matrix-vector based residual.
 *
 * params    - test parameters; b_flash and b_alg_flat are read.
 * var       - variant number used to build the control trees when impl is
 *             one of the variant/external implementation types.
 * sc_str    - storage-scheme characters: sc_str[0] for A, sc_str[1] for B.
 * datatype  - datatype of all test objects.
 * p_cur     - current problem size; m and n are derived from it.
 * pci       - index into pc_str selecting side/uplo/trans/diag.
 * n_repeats - number of timed repetitions; the fastest one is reported.
 * impl      - which implementation to exercise (FLA_TEST_* constant).
 * perf      - output: performance of the fastest repeat, scaled by
 *             FLOPS_PER_UNIT_PERF (times 4 for complex datatypes).
 * residual  - output: FLA_Max_elemwise_diff( y, z ) for the two vectors
 *             described at the residual computation below.
 */
void libfla_test_trsm_experiment( test_params_t params,
                                  unsigned int  var,
                                  char*         sc_str,
                                  FLA_Datatype  datatype,
                                  unsigned int  p_cur,
                                  unsigned int  pci,
                                  unsigned int  n_repeats,
                                  signed int    impl,
                                  double*       perf,
                                  double*       residual )
{
    dim_t        b_flash    = params.b_flash;
    dim_t        b_alg_flat = params.b_alg_flat;
    double       time_min   = 1e9;
    double       time;
    double       flops;
    unsigned int i;
    unsigned int m;
    signed int   m_input    = -1;
    unsigned int n;
    signed int   n_input    = -1;
    FLA_Side     side;
    FLA_Uplo     uplo;
    FLA_Trans    trans;
    FLA_Diag     diag;
    FLA_Obj      A, B, x, y, z;
    FLA_Obj      alpha;
    FLA_Obj      B_save;
    FLA_Obj      A_test, B_test;

    // Determine the dimensions. A negative *_input divides p_cur by its
    // magnitude; otherwise the dimension is p_cur itself.
    if ( m_input < 0 ) m = p_cur / abs(m_input);
    else               m = p_cur;
    if ( n_input < 0 ) n = p_cur / abs(n_input);
    else               n = p_cur;

    // Translate parameter characters to libflame constants.
    FLA_Param_map_char_to_flame_side( &pc_str[pci][0], &side );
    FLA_Param_map_char_to_flame_uplo( &pc_str[pci][1], &uplo );
    FLA_Param_map_char_to_flame_trans( &pc_str[pci][2], &trans );
    FLA_Param_map_char_to_flame_diag( &pc_str[pci][3], &diag );

    // A is square; its order matches the dimension of B it is applied to
    // (m for side = left, n for side = right). B is m x n.
    if ( side == FLA_LEFT )
        libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, m, &A );
    else
        libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n, n, &A );
    libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], m, n, &B );

    // Vectors for the residual check: x has n entries, y and z have m.
    FLA_Obj_create( datatype, n, 1, 0, 0, &x );
    FLA_Obj_create( datatype, m, 1, 0, 0, &y );
    FLA_Obj_create( datatype, m, 1, 0, 0, &z );

    // Initialize the test matrices.
    FLA_Random_tri_matrix( uplo, diag, A );
    FLA_Random_matrix( B );

    // Initialize the test vectors.
    FLA_Random_matrix( x );
    FLA_Set( FLA_ZERO, y );
    FLA_Set( FLA_ZERO, z );

    // Set constants.
    alpha = FLA_TWO;

    // Keep the original right-hand side: each repeat restarts from it, and
    // the residual check needs it.
    FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, B, &B_save );

    // Use hierarchical matrices if we're testing the FLASH front-end.
    if ( impl == FLA_TEST_HIER_FRONT_END )
    {
        FLASH_Obj_create_hier_copy_of_flat( A, 1, &b_flash, &A_test );
        FLASH_Obj_create_hier_copy_of_flat( B, 1, &b_flash, &B_test );
    }
    else
    {
        A_test = A;
        B_test = B;
    }

    // Create a control tree for the individual variants.
    if ( impl == FLA_TEST_FLAT_UNB_VAR ||
         impl == FLA_TEST_FLAT_OPT_VAR ||
         impl == FLA_TEST_FLAT_BLK_VAR ||
         impl == FLA_TEST_FLAT_UNB_EXT ||
         impl == FLA_TEST_FLAT_BLK_EXT )
        libfla_test_trsm_cntl_create( var, b_alg_flat );

    // Repeat the experiment n_repeats times and keep the fastest time.
    for ( i = 0; i < n_repeats; ++i )
    {
        if ( impl == FLA_TEST_HIER_FRONT_END )
            FLASH_Obj_hierarchify( B_save, B_test );
        else
            FLA_Copy_external( B_save, B_test );

        time = FLA_Clock();

        libfla_test_trsm_impl( impl, side, uplo, trans, diag, alpha, A_test, B_test );

        time = FLA_Clock() - time;
        time_min = min( time_min, time );
    }

    // For the FLASH front-end, copy the solution back into the flat matrix B
    // and free the hierarchical copies. For flat implementations B_test was
    // assigned from B above, so the solution is already in B.
    if ( impl == FLA_TEST_HIER_FRONT_END )
    {
        FLASH_Obj_flatten( B_test, B );

        FLASH_Obj_free( &A_test );
        FLASH_Obj_free( &B_test );
    }

    // Free the control trees if we're testing the variants.
    if ( impl == FLA_TEST_FLAT_UNB_VAR ||
         impl == FLA_TEST_FLAT_OPT_VAR ||
         impl == FLA_TEST_FLAT_BLK_VAR ||
         impl == FLA_TEST_FLAT_UNB_EXT ||
         impl == FLA_TEST_FLAT_BLK_EXT )
        libfla_test_trsm_cntl_free();

    // Flop count for real datatypes: m^2 n (left) or m n^2 (right).
    // Evaluate it in double precision. The former unsigned int product
    // wrapped around once it exceeded UINT_MAX (e.g. m = n = 1700), which
    // produced wrong performance numbers for large problem sizes.
    if ( side == FLA_LEFT )
        flops = ( double ) m * ( double ) m * ( double ) n;
    else
        flops = ( double ) m * ( double ) n * ( double ) n;

    *perf = flops / time_min / FLOPS_PER_UNIT_PERF;
    if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

    // Compute:
    //   y = B * x
    // and compare to
    //   z = ( alpha * inv(op(A)) * B_save ) x   (side = left)
    //   z = ( alpha * B_save * inv(op(A)) ) x   (side = right)
    FLA_Gemv_external( FLA_NO_TRANSPOSE, FLA_ONE, B, x, FLA_ZERO, y );

    if ( side == FLA_LEFT )
    {
        FLA_Gemv_external( FLA_NO_TRANSPOSE, alpha, B_save, x, FLA_ZERO, z );
        FLA_Trsv_external( uplo, trans, diag, A, z );
    }
    else
    {
        // x is overwritten here; it is not needed again afterwards.
        FLA_Trsv_external( uplo, trans, diag, A, x );
        FLA_Gemv_external( FLA_NO_TRANSPOSE, alpha, B_save, x, FLA_ZERO, z );
    }

    *residual = FLA_Max_elemwise_diff( y, z );

    // Free the supporting flat objects.
    FLA_Obj_free( &B_save );

    // Free the flat test matrices and vectors.
    FLA_Obj_free( &A );
    FLA_Obj_free( &B );
    FLA_Obj_free( &x );
    FLA_Obj_free( &y );
    FLA_Obj_free( &z );
}
308
309
310
311 extern fla_scal_t* fla_scal_cntl_blas;
312 extern fla_gemm_t* fla_gemm_cntl_blas;
313 extern fla_trsm_t* fla_trsm_cntl_blas;
314
libfla_test_trsm_cntl_create(unsigned int var,dim_t b_alg_flat)315 void libfla_test_trsm_cntl_create( unsigned int var,
316 dim_t b_alg_flat )
317 {
318 int var_unb = FLA_UNB_VAR_OFFSET + var;
319 int var_blk = FLA_BLK_VAR_OFFSET + var;
320
321 trsm_cntl_bsize = FLA_Blocksize_create( b_alg_flat,
322 b_alg_flat,
323 b_alg_flat,
324 b_alg_flat );
325
326 trsm_cntl_unb = FLA_Cntl_trsm_obj_create( FLA_FLAT,
327 var_unb,
328 NULL,
329 NULL,
330 NULL,
331 NULL );
332
333 trsm_cntl_blk = FLA_Cntl_trsm_obj_create( FLA_FLAT,
334 var_blk,
335 trsm_cntl_bsize,
336 fla_scal_cntl_blas,
337 fla_trsm_cntl_blas,
338 fla_gemm_cntl_blas );
339 }
340
341
342
libfla_test_trsm_cntl_free(void)343 void libfla_test_trsm_cntl_free( void )
344 {
345 FLA_Blocksize_free( trsm_cntl_bsize );
346
347 FLA_Cntl_obj_free( trsm_cntl_unb );
348 FLA_Cntl_obj_free( trsm_cntl_blk );
349 }
350
351
352
/*
 * Perform B := alpha * op(inv(A)) B (or the right-side form) with the
 * implementation selected by impl.
 *
 * Supported impl values:
 *   - FLA_TEST_HIER_FRONT_END: FLASH_Trsm.
 *   - FLA_TEST_FLAT_FRONT_END: FLA_Trsm.
 *   - FLA_TEST_FLAT_UNB_VAR:   FLA_Trsm_internal with trsm_cntl_unb.
 *   - FLA_TEST_FLAT_BLK_VAR:   FLA_Trsm_internal with trsm_cntl_blk.
 *   - FLA_TEST_FLAT_UNB_EXT:   FLA_Trsm_external.
 * Any other value reports an error and performs no computation.
 */
void libfla_test_trsm_impl( int       impl,
                            FLA_Side  side,
                            FLA_Uplo  uplo,
                            FLA_Trans trans,
                            FLA_Diag  diag,
                            FLA_Obj   alpha,
                            FLA_Obj   A,
                            FLA_Obj   B )
{
    if ( impl == FLA_TEST_HIER_FRONT_END )
    {
        FLASH_Trsm( side, uplo, trans, diag, alpha, A, B );
    }
    else if ( impl == FLA_TEST_FLAT_FRONT_END )
    {
        FLA_Trsm( side, uplo, trans, diag, alpha, A, B );
    }
    else if ( impl == FLA_TEST_FLAT_UNB_VAR )
    {
        FLA_Trsm_internal( side, uplo, trans, diag, alpha, A, B, trsm_cntl_unb );
    }
    else if ( impl == FLA_TEST_FLAT_BLK_VAR )
    {
        FLA_Trsm_internal( side, uplo, trans, diag, alpha, A, B, trsm_cntl_blk );
    }
    else if ( impl == FLA_TEST_FLAT_UNB_EXT )
    {
        FLA_Trsm_external( side, uplo, trans, diag, alpha, A, B );
    }
    else
    {
        libfla_test_output_error( "Invalid implementation type.\n" );
    }
}
388
389