1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
14 #define NUM_PARAM_COMBOS 1
15 #define NUM_MATRIX_ARGS  6
16 #define FIRST_VARIANT    1
17 #define LAST_VARIANT     1
18 
19 // Static variables.
20 static char* op_str                   = "Apply up/downdating Q via UD UT transform";
21 static char* fla_front_str            = "FLA_Apply_QUD_UT";
22 //static char* fla_unb_var_str          = "";
23 //static char* fla_opt_var_str          = "";
24 //static char* fla_blk_var_str          = "";
25 static char* pc_str[NUM_PARAM_COMBOS] = { "" };
26 static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
27                                           1e-11, 1e-12,   // warn, pass for d
28                                           1e-02, 1e-03,   // warn, pass for c
29                                           1e-11, 1e-12 }; // warn, pass for z
30 
31 // Local prototypes.
32 void libfla_test_apqudut_experiment( test_params_t params,
33                                      unsigned int  var,
34                                      char*         sc_str,
35                                      FLA_Datatype  datatype,
36                                      unsigned int  p,
37                                      unsigned int  pci,
38                                      unsigned int  n_repeats,
39                                      signed int    impl,
40                                      double*       perf,
41                                      double*       residual );
42 void libfla_test_apqudut_impl( int     impl,
43                                FLA_Obj T, FLA_Obj W,
44                                           FLA_Obj bR,
45                                FLA_Obj C, FLA_Obj bC,
46                                FLA_Obj D, FLA_Obj bD );
47 
/*
 * Entry point of the FLA_Apply_QUD_UT test.
 *
 * Prints the operation banner and, when op.fla_front equals ENABLE, passes
 * libfla_test_apqudut_experiment() to libfla_test_op_driver() together with
 * the file-level configuration (variant range, parameter combinations,
 * matrix-argument count, thresholds).
 *
 * output_stream is not referenced in this function.
 */
void libfla_test_apqudut( FILE* output_stream, test_params_t params, test_op_t op )
{
	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	// Nothing further to do unless the front-end test was requested.
	if ( op.fla_front != ENABLE )
		return;

	libfla_test_op_driver( fla_front_str, NULL,
	                       FIRST_VARIANT, LAST_VARIANT,
	                       NUM_PARAM_COMBOS, pc_str,
	                       NUM_MATRIX_ARGS,
	                       FLA_TEST_FLAT_FRONT_END,
	                       params, thresh, libfla_test_apqudut_experiment );
}
64 
65 
66 
/*
 * Runs one FLA_Apply_QUD_UT experiment for problem size p_cur.
 *
 * Two triangular factors are built from random B, C, and D: R_BD from the
 * Herk products of B and D, and R_BC from the Herk products of B and C, each
 * followed by FLA_Chol(). Matching right-hand sides bR_BD and bR_BC are
 * formed from bB, bD, and bC. R_BD is then up/downdated with
 * FLA_UDdate_UT( R_BD, C, D, T ), the resulting transform is applied to
 * bR_BD (timed, n_repeats times, on fresh copies of the inputs), and both
 * triangular systems are solved. The maximum element-wise difference between
 * the two solutions is reported as the residual.
 *
 * Parameters:
 *   params    - test parameters; only params.b_alg_flat is read here.
 *   var       - not referenced in this function.
 *   sc_str    - per-matrix specifier characters forwarded to
 *               libfla_test_obj_create(); indices 0 through 5 are read.
 *   datatype  - datatype used for every object created here.
 *   p_cur     - current problem size, from which all dimensions are derived.
 *   pci       - not referenced in this function.
 *   n_repeats - number of timed repetitions; the minimum time is kept.
 *   impl      - implementation selector passed to libfla_test_apqudut_impl().
 *   perf      - output: performance of the fastest repetition.
 *   residual  - output: max element-wise difference of the two solutions.
 */
void libfla_test_apqudut_experiment( test_params_t params,
                                     unsigned int  var,
                                     char*         sc_str,
                                     FLA_Datatype  datatype,
                                     unsigned int  p_cur,
                                     unsigned int  pci,
                                     unsigned int  n_repeats,
                                     signed int    impl,
                                     double*       perf,
                                     double*       residual )
{
	dim_t        b_alg_flat = params.b_alg_flat;
	double       time_min   = 1e9;
	double       time;
	unsigned int i;
	unsigned int mB, mC, mD, n, n_rhs;
	// A negative *_input value -k means "p_cur / k"; a non-negative value
	// means "use p_cur itself" (see the dimension computation below).
	signed int   mB_input    = -1;
	signed int   mC_input    = -4;
	signed int   mD_input    = -4;
	signed int   n_input     = -1;
	signed int   n_rhs_input = -1;
	FLA_Obj      R_BD, R_BC, B, C, D, T, W;
	FLA_Obj      bR_BD, bR_BC, bB, bC, bD;
	FLA_Obj      bR_BD_save, bC_save, bD_save;

	// Determine the dimensions.
	if ( mB_input    < 0 ) mB    = p_cur / abs(mB_input);
	else                   mB    = p_cur;
	if ( mC_input    < 0 ) mC    = p_cur / abs(mC_input);
	else                   mC    = p_cur;
	if ( mD_input    < 0 ) mD    = p_cur / abs(mD_input);
	else                   mD    = p_cur;
	if ( n_input     < 0 ) n     = p_cur / abs(n_input);
	else                   n     = p_cur;
	if ( n_rhs_input < 0 ) n_rhs = p_cur / abs(n_rhs_input);
	else                   n_rhs = p_cur;

	// Create the matrices for the current operation.
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], b_alg_flat, n, &T );

	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], mB, n, &B );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], mC, n, &C );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[2], mD, n, &D );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &R_BC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], n,  n, &R_BD );

	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], mB, n_rhs, &bB );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[4], mC, n_rhs, &bC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[5], mD, n_rhs, &bD );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], n,  n_rhs, &bR_BC );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[3], n,  n_rhs, &bR_BD );

	FLA_Apply_QUD_UT_create_workspace( T, bR_BD, &W );

	// Initialize the test matrices.
	FLA_Random_matrix( B );
	FLA_Random_matrix( C );
	FLA_Random_matrix( D );

	// Initialize the right-hand sides.
	FLA_Random_matrix( bB );
	FLA_Random_matrix( bC );
	FLA_Random_matrix( bD );

	// Initialize the test factorization.
	FLA_Set( FLA_ZERO, R_BD );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, R_BD );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, D, FLA_ONE, R_BD );
	FLA_Chol( FLA_UPPER_TRIANGULAR, R_BD );

	// Initialize the solution factorization.
	FLA_Set( FLA_ZERO, R_BC );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, B, FLA_ONE, R_BC );
	FLA_Herk_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_ONE, C, FLA_ONE, R_BC );
	FLA_Chol( FLA_UPPER_TRIANGULAR, R_BC );

	// Initialize the test right-hand side.
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BD );
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, D, bD, FLA_ONE,  bR_BD );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BD, bR_BD );

	// Initialize the solution right-hand side.
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, B, bB, FLA_ZERO, bR_BC );
	FLA_Gemm_external( FLA_CONJ_TRANSPOSE, FLA_NO_TRANSPOSE, FLA_ONE, C, bC, FLA_ONE,  bR_BC );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE, FLA_NONUNIT_DIAG, FLA_ONE, R_BC, bR_BC );

	// Perform the up/downdate on R_BD, C, D, and T.
	FLA_UDdate_UT( R_BD, C, D, T );

	// Save the original test right-hand sides to temporary objects.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bR_BD, &bR_BD_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bC, &bC_save );
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, bD, &bD_save );

	// Repeat the experiment n_repeats times and record results.
	for ( i = 0; i < n_repeats; ++i )
	{
		// Restore the inputs so every repetition starts from the same data.
		FLA_Copy_external( bR_BD_save, bR_BD );
		FLA_Copy_external( bC_save, bC );
		FLA_Copy_external( bD_save, bD );

		time = FLA_Clock();

		libfla_test_apqudut_impl( impl, T, W,
		                                   bR_BD,
		                                C, bC,
		                                D, bD );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Solve for the solutions of our two systems.
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
	                   FLA_ONE, R_BD, bR_BD );
	FLA_Trsm_external( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
	                   FLA_ONE, R_BC, bR_BC );

	// Compute the maximum element-wise difference between the solutions of
	// the two systems.
	*residual = FLA_Max_elemwise_diff( bR_BD, bR_BC );

	// Compute the performance of the best experiment repeat. The leading
	// cast promotes the whole product to double; without it n * n_rhs is
	// evaluated in unsigned int and wraps for large problem sizes.
	*perf = ( double ) n * n_rhs * ( 2.0 * mC + 2.0 * mD + 0.5 * b_alg_flat + 0.5 ) /
	        time_min / FLOPS_PER_UNIT_PERF;
	// The operation count is scaled by four for complex datatypes.
	if ( FLA_Obj_is_complex( bR_BD ) ) *perf *= 4.0;

	// Free the supporting flat objects.
	FLA_Obj_free( &bR_BD_save );
	FLA_Obj_free( &bC_save );
	FLA_Obj_free( &bD_save );

	// Free the flat test matrices.
	FLA_Obj_free( &B );
	FLA_Obj_free( &C );
	FLA_Obj_free( &D );
	FLA_Obj_free( &R_BC );
	FLA_Obj_free( &R_BD );
	FLA_Obj_free( &bB );
	FLA_Obj_free( &bC );
	FLA_Obj_free( &bD );
	FLA_Obj_free( &bR_BC );
	FLA_Obj_free( &bR_BD );
	FLA_Obj_free( &T );
	FLA_Obj_free( &W );
}
213 
214 
215 
/*
 * Invokes the implementation selected by impl on the given operands.
 *
 * Only FLA_TEST_FLAT_FRONT_END is recognized, in which case
 * FLA_Apply_QUD_UT() is called with FLA_LEFT, FLA_CONJ_TRANSPOSE,
 * FLA_FORWARD, and FLA_COLUMNWISE. Any other value of impl is reported
 * through libfla_test_output_error() and no operation is applied.
 */
void libfla_test_apqudut_impl( int     impl,
                               FLA_Obj T, FLA_Obj W,
                                          FLA_Obj bR_BD,
                               FLA_Obj C, FLA_Obj bC,
                               FLA_Obj D, FLA_Obj bD )
{
	// Reject anything other than the flat front-end up front.
	if ( impl != FLA_TEST_FLAT_FRONT_END )
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
		return;
	}

	FLA_Apply_QUD_UT( FLA_LEFT, FLA_CONJ_TRANSPOSE, FLA_FORWARD, FLA_COLUMNWISE,
	                  T, W,
	                     bR_BD,
	                  C, bC,
	                  D, bD );
}
236 
237