1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12 #include "test_libflame.h"
13
14 #define NUM_PARAM_COMBOS 1
15 #define NUM_MATRIX_ARGS 6
16 #define FIRST_VARIANT 1
17 #define LAST_VARIANT 1
18
19 // Static variables.
20 static char* op_str = "Apply up/downdating Q via UD UT transform";
21 static char* fla_front_str = "FLA_Apply_QUD_UT";
22 //static char* fla_unb_var_str = "";
23 //static char* fla_opt_var_str = "";
24 //static char* fla_blk_var_str = "";
25 static char* pc_str[NUM_PARAM_COMBOS] = { "" };
26 static test_thresh_t thresh = { 1e-02, 1e-03, // warn, pass for s
27 1e-11, 1e-12, // warn, pass for d
28 1e-02, 1e-03, // warn, pass for c
29 1e-11, 1e-12 }; // warn, pass for z
30
31 // Local prototypes.
32 void libfla_test_apqudut_experiment( test_params_t params,
33 unsigned int var,
34 char* sc_str,
35 FLA_Datatype datatype,
36 unsigned int p,
37 unsigned int pci,
38 unsigned int n_repeats,
39 signed int impl,
40 double* perf,
41 double* residual );
42 void libfla_test_apqudut_impl( int impl,
43 FLA_Obj T, FLA_Obj W,
44 FLA_Obj bR,
45 FLA_Obj C, FLA_Obj bC,
46 FLA_Obj D, FLA_Obj bD );
47
/*
   Entry point for the FLA_Apply_QUD_UT test.

   Prints a banner naming the operation. If the flat front-end is enabled in
   op, the function then hands libfla_test_apqudut_experiment to the generic
   test driver, together with this file's variant range, parameter
   combinations and thresholds. output_stream is not used here.
*/
void libfla_test_apqudut( FILE* output_stream, test_params_t params, test_op_t op )
{
	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	// The flat front-end is the only implementation tested for this
	// operation; nothing else to do when it is disabled.
	if ( op.fla_front != ENABLE )
	{
		return;
	}

	libfla_test_op_driver( fla_front_str, NULL,
	                       FIRST_VARIANT, LAST_VARIANT,
	                       NUM_PARAM_COMBOS, pc_str,
	                       NUM_MATRIX_ARGS,
	                       FLA_TEST_FLAT_FRONT_END,
	                       params, thresh, libfla_test_apqudut_experiment );
}
64
65
66
/*
   Map a dimension "input" to an actual matrix dimension for problem size
   p_cur. A negative input requests p_cur divided by the magnitude of the
   input. A non-negative input requests p_cur itself; its value is otherwise
   ignored.
*/
static unsigned int libfla_test_apqudut_dim( signed int input, unsigned int p_cur )
{
	if ( input < 0 )
	{
		return p_cur / abs( input );
	}
	return p_cur;
}

/*
   Run one FLA_Apply_QUD_UT experiment at problem size p_cur.

   Setup:
     - Build a "test" Cholesky factor R_BD from B^H B + D^H D.
     - Build a "solution" factor R_BC from B^H B + C^H C.
     - Build the matching right-hand sides
         bR_BD = R_BD^{-H} ( B^H bB + D^H bD )
         bR_BC = R_BC^{-H} ( B^H bB + C^H bC ).

   Test:
     - FLA_UDdate_UT is applied to R_BD, C, D and T. The test then compares
       against R_BC, which presumes the up/downdate adds C and removes D;
       see the FLA_UDdate_UT docs to confirm.
     - The selected implementation applies the resulting transform to
       (bR_BD, bC, bD), n_repeats times from the same saved inputs.

   Check:
     - Both triangular systems are solved, and *residual receives the maximum
       element-wise difference between the two solutions.

   params    - only params.b_alg_flat is read (row dimension of T).
   var, pci  - not used by this operation.
   sc_str    - storage-scheme characters, indexed here as:
                 [0] T, B, R_BC, R_BD, bB   [1] C   [2] D
                 [3] bR_BC, bR_BD           [4] bC  [5] bD
   datatype  - FLAME datatype of every object created.
   p_cur     - problem size from which all dimensions are derived.
   n_repeats - number of timed repetitions; the fastest one is reported.
   impl      - implementation selector forwarded to libfla_test_apqudut_impl.
   perf      - out: estimated performance of the fastest repetition, in units
               of FLOPS_PER_UNIT_PERF (times 4 for complex datatypes).
   residual  - out: maximum element-wise difference of the two solutions.
*/
void libfla_test_apqudut_experiment( test_params_t params,
                                     unsigned int var,
                                     char* sc_str,
                                     FLA_Datatype datatype,
                                     unsigned int p_cur,
                                     unsigned int pci,
                                     unsigned int n_repeats,
                                     signed int impl,
                                     double* perf,
                                     double* residual )
{
	dim_t        b_alg_flat  = params.b_alg_flat;
	double       time_min    = 1e9;
	double       time;
	unsigned int i;
	unsigned int mB, mC, mD, n, n_rhs;
	signed int   mB_input    = -1;
	signed int   mC_input    = -4;
	signed int   mD_input    = -4;
	signed int   n_input     = -1;
	signed int   n_rhs_input = -1;
	FLA_Obj      R_BD, R_BC, B, C, D, T, W;
	FLA_Obj      bR_BD, bR_BC, bB, bC, bD;
	FLA_Obj      bR_BD_save, bC_save, bD_save;

	// Determine the dimensions.
	mB    = libfla_test_apqudut_dim( mB_input,    p_cur );
	mC    = libfla_test_apqudut_dim( mC_input,    p_cur );
	mD    = libfla_test_apqudut_dim( mD_input,    p_cur );
	n     = libfla_test_apqudut_dim( n_input,     p_cur );
	n_rhs = libfla_test_apqudut_dim( n_rhs_input, p_cur );

	// Create the matrices for the current operation.
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], b_alg_flat, n, &T );

	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], mB, n, &B );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], mC, n, &C );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[2], mD, n, &D );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n, n, &R_BC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n, n, &R_BD );

	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], mB, n_rhs, &bB );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[4], mC, n_rhs, &bC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[5], mD, n_rhs, &bD );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], n, n_rhs, &bR_BC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], n, n_rhs, &bR_BD );

	FLA_Apply_QUD_UT_create_workspace( T, bR_BD, &W );

	// Initialize the test matrices.
	FLA_Random_matrix( B );
	FLA_Random_matrix( C );
	FLA_Random_matrix( D );

	// Initialize the right-hand sides.
	FLA_Random_matrix( bB );
	FLA_Random_matrix( bC );
	FLA_Random_matrix( bD );

	// Initialize the test factorization: R_BD = chol( B^H B + D^H D ).
	FLA_Set( FLA_ZERO, R_BD );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, R_BD );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, D, FLA_ONE, R_BD );
	FLA_Chol( FLA_UPPER_TRIANGULAR, R_BD );

	// Initialize the solution factorization: R_BC = chol( B^H B + C^H C ).
	FLA_Set( FLA_ZERO, R_BC );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, R_BC );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, C, FLA_ONE, R_BC );
	FLA_Chol( FLA_UPPER_TRIANGULAR, R_BC );

	// Initialize the test right-hand side: R_BD^{-H} ( B^H bB + D^H bD ).
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BD );
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, D, bD, FLA_ONE, bR_BD );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BD, bR_BD );

	// Initialize the solution right-hand side: R_BC^{-H} ( B^H bB + C^H bC ).
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BC );
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, C, bC, FLA_ONE, bR_BC );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BC, bR_BC );

	// Perform the up/downdate on R_BD, C, D, and T.
	FLA_UDdate_UT( R_BD, C, D, T );

	// Save the original test right-hand sides so every repetition starts
	// from the same inputs.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bR_BD, &bR_BD_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bC, &bC_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bD, &bD_save );

	// Repeat the experiment n_repeats times and record the fastest run.
	for ( i = 0; i < n_repeats; ++i )
	{
		FLA_Copy_external( bR_BD_save, bR_BD );
		FLA_Copy_external( bC_save, bC );
		FLA_Copy_external( bD_save, bD );

		time = FLA_Clock();

		libfla_test_apqudut_impl( impl, T, W,
		                          bR_BD,
		                          C, bC,
		                          D, bD );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Solve for the solutions of our two systems.
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
	                   FLA_ONE, R_BD, bR_BD );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
	                   FLA_ONE, R_BC, bR_BC );

	// Compute the maximum element-wise difference between the solutions of
	// the two systems.
	*residual = FLA_Max_elemwise_diff( bR_BD, bR_BC );

	// Compute the performance of the best experiment repeat. Convert to
	// double before multiplying so n * n_rhs cannot wrap around in
	// unsigned int arithmetic for large problem sizes.
	*perf = ( double ) n * n_rhs * ( 2.0 * mC + 2.0 * mD + 0.5 * b_alg_flat + 0.5 ) /
	        time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( bR_BD ) ) *perf *= 4.0;

	// Free the supporting flat objects.
	FLA_Obj_free( &bR_BD_save );
	FLA_Obj_free( &bC_save );
	FLA_Obj_free( &bD_save );

	// Free the flat test matrices.
	FLA_Obj_free( &B );
	FLA_Obj_free( &C );
	FLA_Obj_free( &D );
	FLA_Obj_free( &R_BC );
	FLA_Obj_free( &R_BD );
	FLA_Obj_free( &bB );
	FLA_Obj_free( &bC );
	FLA_Obj_free( &bD );
	FLA_Obj_free( &bR_BC );
	FLA_Obj_free( &bR_BD );
	FLA_Obj_free( &T );
	FLA_Obj_free( &W );
}
213
214
215
/*
   Perform one application of the UD UT transform using the implementation
   selected by impl.

   Only FLA_TEST_FLAT_FRONT_END is supported. It calls FLA_Apply_QUD_UT with
   FLA_LEFT / FLA_CONJ_TRANSPOSE / FLA_FORWARD / FLA_COLUMNWISE, using T and
   workspace W, on bR_BD, (C, bC) and (D, bD). Any other value of impl
   reports an error through libfla_test_output_error and does nothing else.
*/
void libfla_test_apqudut_impl( int impl,
                               FLA_Obj T, FLA_Obj W,
                               FLA_Obj bR_BD,
                               FLA_Obj C, FLA_Obj bC,
                               FLA_Obj D, FLA_Obj bD )
{
	if ( impl != FLA_TEST_FLAT_FRONT_END )
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
		return;
	}

	FLA_Apply_QUD_UT( FLA_LEFT, FLA_CONJ_TRANSPOSE, FLA_FORWARD, FLA_COLUMNWISE,
	                  T, W,
	                  bR_BD,
	                  C, bC,
	                  D, bD );
}
236
237