1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 #include "test_libflame.h"
13 
14 //#define NUM_PARAM_COMBOS 16
15 #define NUM_PARAM_COMBOS 1
16 #define NUM_MATRIX_ARGS  2
17 #define FIRST_VARIANT    1
18 #define LAST_VARIANT     1
19 
20 // Static variables.
21 static char* op_str                   = "Apply CAQ via UT transform (incremental)";
22 static char* flash_front_str          = "FLASH_Apply_CAQ_UT_inc";
23 //static char* fla_front_str            = "FLA_Apply_Q_UT_inc";
24 //static char* fla_unb_var_str          = "";
25 //static char* fla_opt_var_str          = "";
26 //static char* fla_blk_var_str          = "";
27 //static char* pc_str[NUM_PARAM_COMBOS] = { "lnfc", "lnfr", "lnbc", "lnbr",
28 //                                          "lhfc", "lhfr", "lhbc", "lhbr",
29 //                                          "rnfc", "rnfr", "rnbc", "rnbr",
30 //                                          "rhfc", "rhfr", "rhbc", "rhbr" };
31 static char* pc_str[NUM_PARAM_COMBOS] = { "lhfc" };
32 static test_thresh_t thresh           = { 1e-02, 1e-03,   // warn, pass for s
33                                           1e-11, 1e-12,   // warn, pass for d
34                                           1e-02, 1e-03,   // warn, pass for c
35                                           1e-11, 1e-12 }; // warn, pass for z
36 
37 // Local prototypes.
38 void libfla_test_apcaqutinc_experiment( test_params_t params,
39                                         unsigned int  var,
40                                         char*         sc_str,
41                                         FLA_Datatype  datatype,
42                                         unsigned int  p,
43                                         unsigned int  pci,
44                                         unsigned int  n_repeats,
45                                         signed int    impl,
46                                         double*       perf,
47                                         double*       residual );
48 void libfla_test_apcaqutinc_impl( int        impl,
49                                   dim_t      p,
50                                   FLA_Side   side,
51                                   FLA_Trans  trans,
52                                   FLA_Direct direct,
53                                   FLA_Store  storev,
54                                   FLA_Obj    A,
55                                   FLA_Obj    ATW,
56                                   FLA_Obj    R,
57                                   FLA_Obj    RTW,
58                                   FLA_Obj    W,
59                                   FLA_Obj    B );
60 
void libfla_test_apcaqutinc( FILE* output_stream, test_params_t params, test_op_t op )
{
	// Announce the operation, then leave a blank line before any results.
	libfla_test_output_info( "--- %s ---\n", op_str );
	libfla_test_output_info( "\n" );

	// This tester only drives the FLASH (hierarchical) front-end; nothing
	// else to do when it is disabled.
	if ( op.flash_front != ENABLE )
		return;

	libfla_test_op_driver( flash_front_str, NULL,
	                       FIRST_VARIANT, LAST_VARIANT,
	                       NUM_PARAM_COMBOS, pc_str,
	                       NUM_MATRIX_ARGS,
	                       FLA_TEST_HIER_FRONT_END,
	                       params, thresh, libfla_test_apcaqutinc_experiment );
}
76 
77 
78 
/*
 * Run one timing/accuracy experiment for FLASH_Apply_CAQ_UT_inc.
 *
 * A random m x n matrix A (m > n) and a random m x 1 right-hand side B are
 * created. A is factored with FLASH_CAQR_UT_inc. The resulting transform is
 * then applied to a fresh copy of B n_repeats times, as selected by
 * pc_str[pci], and the fastest repeat is timed. The transformed right-hand
 * side is used to solve a least-squares problem, from which the residual is
 * formed.
 *
 * params    - supplies the FLASH blocksize (b_flash) and the algorithmic
 *             blocksize (b_alg_hier).
 * var       - variant number (unused here).
 * sc_str    - storage characters: sc_str[0] for A, sc_str[1] for vectors.
 * datatype  - datatype of all operands.
 * p_cur     - current problem size; m and n are multiples of it.
 * pci       - index into pc_str of the parameter combination to test.
 * n_repeats - number of timed repetitions.
 * impl      - implementation to test (passed to libfla_test_apcaqutinc_impl).
 * perf      - on return, performance of the fastest repetition.
 * residual  - on return, the 2-norm of A^H ( A x - b ).
 */
void libfla_test_apcaqutinc_experiment( test_params_t params,
                                        unsigned int  var,
                                        char*         sc_str,
                                        FLA_Datatype  datatype,
                                        unsigned int  p_cur,
                                        unsigned int  pci,
                                        unsigned int  n_repeats,
                                        signed int    impl,
                                        double*       perf,
                                        double*       residual )
{
	dim_t        b_flash    = params.b_flash;
	dim_t        b_alg_hier = params.b_alg_hier;
	double       time_min   = 1e9;
	double       time;
	unsigned int i;
	unsigned int m, n;
	unsigned int p;
	signed int   m_input;
	signed int   n_input;
	FLA_Side     side;
	FLA_Trans    trans;
	FLA_Direct   direct;
	FLA_Store    storev;
	FLA_Obj      A, X, B, Y, norm;
	FLA_Obj      B_save;
	FLA_Obj      A_test, ATW_test;
	FLA_Obj      R_test, RTW_test;
	FLA_Obj      W_test, X_test, B_test;

	// Translate parameter characters to libflame constants.
	FLA_Param_map_char_to_flame_side( &pc_str[pci][0], &side );
	FLA_Param_map_char_to_flame_trans( &pc_str[pci][1], &trans );
	FLA_Param_map_char_to_flame_direct( &pc_str[pci][2], &direct );
	FLA_Param_map_char_to_flame_storev( &pc_str[pci][3], &storev );

	// Only column-wise storage (factorization via CAQR) is set up below.
	// Reject anything else before m_input, n_input and p would be read
	// uninitialized and before unset hierarchical objects would be used.
	if ( storev != FLA_COLUMNWISE )
	{
		libfla_test_output_error( "Row-wise storage is not yet supported.\n" );
		*perf     = 0.0;
		*residual = 1.0; // exceeds every warn/pass threshold, so never a pass
		return;
	}

	// We want to make sure the Apply_Q_UT_inc routines work with rectangular
	// matrices, so we use m > n when testing with column-wise storage (via
	// QR factorization). Negative inputs are multipliers of p_cur.
	m_input = -8;
	n_input = -1;
	p       = 4;  // p <= abs(m_input) must hold!

	// Determine the dimensions.
	if ( m_input < 0 ) m = p_cur * abs(m_input);
	else               m = p_cur;
	if ( n_input < 0 ) n = p_cur * abs(n_input);
	else               n = p_cur;

	// Create the flat matrices: A is m x n, B is the m x 1 right-hand side,
	// X receives the n x 1 solution and Y the n x 1 normal-equations residual.
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[0], m, n, &A );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], m, 1, &B );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], n, 1, &X );
	libfla_test_obj_create( datatype, FLA_NO_TRANSPOSE, sc_str[1], n, 1, &Y );

	// Initialize the test matrices.
	FLA_Random_matrix( A );
	FLA_Random_matrix( B );

	// Use hierarchical matrices since we're testing the FLASH front-end.
	FLASH_CAQR_UT_inc_create_hier_matrices( p, A, 1, &b_flash, b_alg_hier,
	                                        &A_test, &ATW_test, &R_test, &RTW_test );
	FLASH_Obj_create_hier_copy_of_flat( B, 1, &b_flash, &B_test );
	FLASH_Obj_create_hier_copy_of_flat( X, 1, &b_flash, &X_test );
	FLASH_Apply_CAQ_UT_inc_create_workspace( p, RTW_test, B_test, &W_test );

	// Create a real scalar object to hold the norm of Y.
	FLA_Obj_create( FLA_Obj_datatype_proj_to_real( A ), 1, 1, 0, 0, &norm );

	// Save the original right-hand side so every repeat starts from it.
	FLA_Obj_create_copy_of( FLA_NO_TRANSPOSE, B, &B_save );

	// Compute the CAQR (Householder-based) factorization of A.
	FLASH_CAQR_UT_inc( p, A_test, ATW_test, R_test, RTW_test );

	// Repeat the experiment n_repeats times and record the fastest run.
	for ( i = 0; i < n_repeats; ++i )
	{
		FLASH_Obj_hierarchify( B_save, B_test );

		time = FLA_Clock();

		libfla_test_apcaqutinc_impl( impl, p, side, trans, direct, storev,
		                             A_test, ATW_test, R_test, RTW_test, W_test, B_test );

		time = FLA_Clock() - time;
		time_min = min( time_min, time );
	}

	// Verify via a least-squares solve. B_test now holds the transformed
	// right-hand side, so solving R x = B_test(0:n-1) yields x. The vector
	// y = A^H ( A x - b ), the normal-equations residual, should then be
	// (near) zero.
	if ( impl == FLA_TEST_HIER_FRONT_END )
	{
		FLA_Obj  RT, RB;
		FLA_Obj  BT, BB;

		FLASH_Part_create_2x1( R_test,   &RT,
		                                 &RB,    FLASH_Obj_scalar_width( R_test ), FLA_TOP );
		FLASH_Part_create_2x1( B_test,   &BT,
		                                 &BB,    FLASH_Obj_scalar_width( R_test ), FLA_TOP );

		FLASH_Trsm( FLA_LEFT, FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE, FLA_NONUNIT_DIAG,
		            FLA_ONE, RT, BT );
		FLASH_Copy( BT, X_test );

		FLASH_Part_free_2x1( &RT,
		                     &RB );
		FLASH_Part_free_2x1( &BT,
		                     &BB );

		FLASH_Obj_flatten( X_test, X );

		// B := A x - B, then Y := A^H B.
		FLA_Gemv_external( FLA_NO_TRANSPOSE, FLA_ONE, A, X, FLA_MINUS_ONE, B );
		FLA_Gemv_external( FLA_CONJ_TRANSPOSE, FLA_ONE, A, B, FLA_ZERO, Y );
	}

	// Free the hierarchical matrices. They are created unconditionally above,
	// so they must be freed unconditionally to avoid leaking them.
	FLASH_Obj_free( &A_test );
	FLASH_Obj_free( &ATW_test );
	FLASH_Obj_free( &R_test );
	FLASH_Obj_free( &RTW_test );
	FLASH_Obj_free( &W_test );
	FLASH_Obj_free( &B_test );
	FLASH_Obj_free( &X_test );

	// Compute the norm of Y.
	FLA_Nrm2_external( Y, norm );
	FLA_Obj_extract_real_scalar( norm, residual );

	// Compute the performance of the best experiment repeat
	// (B has a single column).
	*perf = (  4.0 * m * n -
	           2.0 * n * n ) / time_min / FLOPS_PER_UNIT_PERF;
	if ( FLA_Obj_is_complex( A ) ) *perf *= 4.0;

	// Free the supporting flat objects.
	FLA_Obj_free( &B_save );

	// Free the flat test matrices.
	FLA_Obj_free( &A );
	FLA_Obj_free( &B );
	FLA_Obj_free( &X );
	FLA_Obj_free( &Y );
	FLA_Obj_free( &norm );
}
247 
248 
249 
/*
 * Invoke the implementation under test.
 *
 * Only the FLASH hierarchical front-end (FLA_TEST_HIER_FRONT_END) is
 * supported. It forwards every argument unchanged to FLASH_Apply_CAQ_UT_inc.
 * Any other impl value is reported as an error.
 */
void libfla_test_apcaqutinc_impl( int        impl,
                                  dim_t      p,
                                  FLA_Side   side,
                                  FLA_Trans  trans,
                                  FLA_Direct direct,
                                  FLA_Store  storev,
                                  FLA_Obj    A,
                                  FLA_Obj    ATW,
                                  FLA_Obj    R,
                                  FLA_Obj    RTW,
                                  FLA_Obj    W,
                                  FLA_Obj    B )
{
	if ( impl != FLA_TEST_HIER_FRONT_END )
	{
		libfla_test_output_error( "Invalid implementation type.\n" );
		return;
	}

	FLASH_Apply_CAQ_UT_inc( p, side, trans, direct, storev,
	                        A, ATW, R, RTW, W, B );
}
274 
275